Commit 7d1a83a9 authored by aiss

push Deepspeed 0.6.3 rocm version

parent ab5534fc
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <curand_kernel.h>
#include "context.h"
#include "cublas_wrappers.h"
#define THREADS 256
#define TILE_DIM 32
#define minus_infinity -1 * std::numeric_limits<float>::infinity()
#define FINAL_MASK 0xffffffff
template <typename T>
void launch_fused_add2(T* out,
const T* inp1,
const T* inp2,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream);
template <typename T>
void launch_fused_add4(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
const T* inp4,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream);
template <typename T>
void launch_fused_add3(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
int batch_size,
int seq_length,
int hidden_size,
cudaStream_t& stream);
// !!! This is a file automatically generated by hipify!!!
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hiprand/hiprand_kernel.h>
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#define THREADS 256
#define TILE_DIM 32
#define minus_infinity -1 * std::numeric_limits<float>::infinity()
#define FINAL_MASK 0xffffffff
template <typename T>
void launch_fused_add2(T* out,
const T* inp1,
const T* inp2,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
template <typename T>
void launch_fused_add4(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
const T* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
template <typename T>
void launch_fused_add3(T* out,
const T* inp1,
const T* inp2,
const T* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream);
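A minimal usage sketch for the launchers declared above, assuming launch_fused_add2 writes the element-wise sum of inp1 and inp2 into out for a [batch, seq, hidden] activation tensor; the helper below is illustrative and not part of this commit, and error checking is omitted.
// Illustrative only: launch_fused_add2 is assumed to compute out = inp1 + inp2
// over a [batch, seq, hidden] activation tensor on the given stream.
inline void fused_add2_sketch(int batch, int seq, int hidden)
{
    const size_t count = static_cast<size_t>(batch) * seq * hidden;
    float *out = nullptr, *a = nullptr, *b = nullptr;
    hipMalloc((void**)&out, count * sizeof(float));
    hipMalloc((void**)&a, count * sizeof(float));
    hipMalloc((void**)&b, count * sizeof(float));
    hipStream_t stream;
    hipStreamCreate(&stream);
    launch_fused_add2<float>(out, a, b, batch, seq, hidden, stream);
    hipStreamSynchronize(stream);
    hipFree(out);
    hipFree(a);
    hipFree(b);
    hipStreamDestroy(stream);
}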
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_cuda_layers.h"
using namespace std;
template <typename T>
class Normalize_Layer {
public:
struct Config {
uint32_t batchSize;
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
float epsilon = 1e-12,
bool training = true,
bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(epsilon),
training(training),
useMean(useMean)
{
}
};
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz,  // batch * seq
T* vals,
const T* residual,
const T* gamma,
const T* betta,
cudaStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means);
}
void Forward(int bsz,
T* vals,
const T* residual,
const T* gamma,
const T* betta,
cudaStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
cudaStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
};
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <fstream>
#include "custom_hip_layers.h"
using namespace std;
template <typename T>
class Normalize_Layer {
public:
struct Config {
uint32_t batchSize;
uint32_t seqLength;
uint32_t hiddenDim;
float epsilon;
bool training;
bool useMean;
Config(uint32_t batch,
uint32_t seq,
uint32_t h,
float epsilon = 1e-12,
bool training = true,
bool useMean = true)
: batchSize(batch),
seqLength(seq),
hiddenDim(h),
epsilon(epsilon),
training(training),
useMean(useMean)
{
}
};
Normalize_Layer(Config config)
: config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr)
{
}
~Normalize_Layer() {}
void ForwardCheckpoint(int bsz, // batch * seq
T* vals,
const T* residual,
const T* gamma,
const T* betta,
hipStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars,
means);
}
void Forward(int bsz,
T* vals,
const T* residual,
const T* gamma,
const T* betta,
hipStream_t& stream,
bool preLayerNorm = false)
{
launch_bias_residual_layer_norm(vals,
residual,
gamma,
betta,
config_.epsilon,
bsz,
config_.hiddenDim,
stream,
preLayerNorm,
config_.training,
vars);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward(out_grad,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void Backward(int bsz,
const T* out_grad,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward(out_grad,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_in = nullptr)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_in,
vars,
means,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream);
}
void BackwardFusedAdd(int bsz,
const T* out_grad1,
const T* out_grad2,
const T* gamma,
const T* betta,
T* gamma_grad,
T* betta_grad,
hipStream_t stream[2],
T* inp_grad_out,
const T* norm_out)
{
launch_layerNorm_backward_fused_add(out_grad1,
out_grad2,
norm_out,
vars,
gamma,
gamma_grad,
betta_grad,
inp_grad_out,
bsz,
config_.hiddenDim,
stream,
!config_.useMean,
betta);
}
inline bool UseMean() const { return config_.useMean; }
inline void SetVar(T* variance)
{
if (!variance) { throw std::runtime_error("Normalize variance is null."); }
vars = variance;
}
inline void SetMean(T* mean)
{
if (!mean) { throw std::runtime_error("Normalize mean is null."); }
means = mean;
}
private:
Config config_;
T* vars;
T* means;
T* vals_hat;
};
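A minimal usage sketch for Normalize_Layer, assuming vals is the normalized output and residual the input (as the launcher's argument names suggest); the helper name and the buffer handling are illustrative, not part of this commit.
// Illustrative only: pre-/post-LayerNorm forward over bsz = batch * seq rows.
// vars_buf and means_buf are assumed to be caller-allocated device buffers with one
// float per row; they are retained by the layer for the backward pass.
inline void layer_norm_forward_sketch(float* out,
                                      const float* residual,
                                      const float* gamma,
                                      const float* beta,
                                      float* vars_buf,
                                      float* means_buf,
                                      int batch,
                                      int seq,
                                      int hidden,
                                      hipStream_t stream)
{
    Normalize_Layer<float>::Config cfg(batch, seq, hidden);
    Normalize_Layer<float> norm(cfg);
    norm.SetVar(vars_buf);
    norm.SetMean(means_buf);
    norm.ForwardCheckpoint(batch * seq, out, residual, gamma, beta, stream);
}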
#pragma once
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <cooperative_groups.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include <stdlib.h>
#include <cassert>
#include <iostream>
#pragma once
#if (__x86_64__ || __i386__)
#include <cpuid.h>
#include <x86intrin.h>
#endif
#define TILE (128 * 1024 * 1024)
#if defined(__AVX512__) or defined(__AVX256__)
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
#if defined(__AVX512__)
#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm512_loadu_ps(x)
#define SIMD_SET(x) _mm512_set1_ps(x)
#define SIMD_ADD(x, y) _mm512_add_ps(x, y)
#define SIMD_MUL(x, y) _mm512_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm512_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm512_div_ps(x, y)
#define SIMD_WIDTH 16
#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))
#define INTV __m256i
#elif defined(__AVX256__)
#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d)
#define SIMD_LOAD(x) _mm256_loadu_ps(x)
#define SIMD_SET(x) _mm256_set1_ps(x)
#define SIMD_ADD(x, y) _mm256_add_ps(x, y)
#define SIMD_MUL(x, y) _mm256_mul_ps(x, y)
#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c)
#define SIMD_SQRT(x) _mm256_sqrt_ps(x)
#define SIMD_DIV(x, y) _mm256_div_ps(x, y)
#define SIMD_WIDTH 8
#define SIMD_LOAD2(x, h) \
((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm256_storeu_ps(x, d))
#define INTV __m128i
#endif
union AVX_Data {
#if defined(__AVX512__)
__m512 data;
#elif defined(__AVX256__)
__m256 data;
#endif
// float data_f[16];
};
template <int span>
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) {
SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision);
}
}
template <int span>
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) {
dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision);
}
}
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) {
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data);
}
}
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) {
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data);
}
}
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) {
dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data);
}
}
template <int span>
inline void simd_sqrt(AVX_Data* dst, AVX_Data* src)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); }
}
template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); }
}
template <int span>
inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); }
}
template <int span>
inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r)
{
#pragma unroll
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); }
}
#endif
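A minimal sketch of how the SIMD helpers above compose, assuming an __AVX256__ or __AVX512__ build and fp32 buffers; the function itself is illustrative and not part of this commit.
#if defined(__AVX512__) or defined(__AVX256__)
// Illustrative only: a vectorized y = a * x + y loop built from the helpers above,
// for fp32 buffers whose length need not be a multiple of SIMD_WIDTH * span.
inline void simd_axpy_sketch(float* y, float* x, float a, size_t n)
{
    constexpr int span = 4;                          // registers handled per iteration
    const size_t step = (size_t)SIMD_WIDTH * span;   // floats handled per iteration
    AVX_Data alpha;
    alpha.data = SIMD_SET(a);
    size_t i = 0;
    for (; i + step <= n; i += step) {
        AVX_Data x_r[span];
        AVX_Data y_r[span];
        simd_load<span>(x_r, x + i, false);    // false -> plain fp32 loads
        simd_load<span>(y_r, y + i, false);
        simd_fma<span>(y_r, x_r, alpha, y_r);  // per lane: y_r = x_r * alpha + y_r
        simd_store<span>(y + i, y_r, false);
    }
    for (; i < n; ++i) { y[i] = a * x[i] + y[i]; }   // scalar tail
}
#endif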
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "custom_cuda_layers.h"
#include <fstream>
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t batchSize;
size_t heads;
size_t seq_length;
size_t prob_depth;
float temperature;
bool mem_alloc;
Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
: batchSize(batch),
heads(h),
seq_length(seq),
prob_depth(prob_size),
temperature(1.0),
mem_alloc(mem_alloc)
{
}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream)
{
launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
}
void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream)
{
launch_attn_softmax_backward_v2<T>(
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline size_t GetNumHeads() const { return config_.heads; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
};
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "custom_hip_layers.h"
#include <fstream>
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t batchSize;
size_t heads;
size_t seq_length;
size_t prob_depth;
float temperature;
bool mem_alloc;
Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
: batchSize(batch),
heads(h),
seq_length(seq),
prob_depth(prob_size),
temperature(1.0),
mem_alloc(mem_alloc)
{
}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
{
launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
}
void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
{
launch_attn_softmax_backward_v2<T>(
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline size_t GetNumHeads() const { return config_.heads; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
};
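A minimal usage sketch for the Softmax wrapper above, assuming vals holds the attention logits to be normalized in place and attn_mask is the additive mask consumed by launch_attn_softmax; the helper is illustrative, not part of this commit.
// Illustrative only: run the fused masked softmax in place on attention scores.
// attn_scores and attn_mask are assumed to be device buffers owned by the caller.
template <typename T>
void attn_softmax_sketch(T* attn_scores,
                         const T* attn_mask,
                         int batch,
                         int heads,
                         int seq,
                         hipStream_t stream)
{
    typename Softmax<T>::Config cfg(batch, heads, seq);
    Softmax<T> softmax(cfg);
    softmax.Forward(batch, attn_scores, attn_mask, stream);
}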
#pragma once
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include "context.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
cublasOperation_t op_A;
cublasOperation_t op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
cublasOperation_t opA,
cublasOperation_t opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
cublasHandle_t handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m);
int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output),
(_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b),
inpGradA,
CUBLAS_OP_N,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[1]));
#else
cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif
// A need to transpose.
cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
CUBLAS_OP_N,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[2]));
#else
cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context_hip.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
rocblas_operation op_A;
rocblas_operation op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
rocblas_operation opA,
rocblas_operation opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
rocblas_handle handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
(_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
inpGradA,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[1]));
#else
cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif
// A need to transpose.
rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[2]));
#else
cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
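A minimal usage sketch for the rocBLAS build of StridedBatchGemm above; the algo indices are assumed to map onto rocblas_gemm_algo (0 being the standard algorithm), and the helper itself is illustrative, not part of this commit.
// Illustrative only: one batched C = alpha * op_A(A) * op_B(B) launch per call.
// Buffer layout and leading-dimension conventions are those of cublas_strided_batched_gemm.
inline void strided_batch_gemm_sketch(float* out,
                                      const float* a,
                                      const float* b,
                                      int bsz,
                                      int m,
                                      int n,
                                      int k,
                                      rocblas_handle handle)
{
    std::array<int, 3> algos = {0, 0, 0};  // assumed: 0 -> rocblas_gemm_algo_standard
    StridedBatchGemm<float>::Config cfg(
        bsz, m, n, k, 1.0f, 0.0f, rocblas_operation_none, rocblas_operation_none, algos);
    StridedBatchGemm<float> gemm(cfg);
    gemm.Forward(bsz, out, a, b, handle);
}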
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
@@ -26,6 +26,11 @@
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
@@ -46,6 +51,11 @@
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
...
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T* x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
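A minimal sketch of how the dispatch macros above are used, with a hypothetical host-side helper (scale_inplace_cpu and scale_tensor_example are illustrative names, not part of this commit).
// Illustrative only: DISPATCH_FLOAT_AND_HALF selects scalar_t_0 from the runtime dtype
// and expands the trailing expression once per supported case.
template <typename scalar_t>
void scale_inplace_cpu(scalar_t* data, int64_t n, float factor)
{
    for (int64_t i = 0; i < n; ++i) {
        data[i] = static_cast<scalar_t>(static_cast<float>(data[i]) * factor);
    }
}
void scale_tensor_example(at::Tensor& t, float factor)
{
    using namespace at;  // needed for toString in the macro's default branch
    DISPATCH_FLOAT_AND_HALF(t.scalar_type(), 0, "scale_tensor_example",
                            scale_inplace_cpu<scalar_t_0>(t.data_ptr<scalar_t_0>(), t.numel(), factor));
}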
@@ -61,7 +61,7 @@ at::Tensor lamb(at::Tensor& p,
// intermediate for weight L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behaviour is unexpected
at::Tensor w_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
@@ -69,7 +69,7 @@ at::Tensor lamb(at::Tensor& p,
// intermediate for update L2 reduction
// make sure that the threads per block is at least 512 during the kernel launch otherwise the
// behaviour is unexpected
at::Tensor u_l2_i = at::empty(
{512},
p.options().dtype(p.type().scalarType() == at::ScalarType::Half ? at::ScalarType::Float
...
@@ -8,13 +8,16 @@
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_runtime_api.h>
#include <stdio.h>
@@ -30,8 +33,10 @@ struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
#ifndef _WIN32
extern __device__ void error(void);
error();
#endif
return NULL;
}
};
@@ -281,13 +286,13 @@ __global__ void lamb_cuda_kernel_part3(
float lamb_coeff = 1.0;
if (reg_w != 0 && reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 && threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
@@ -462,7 +467,7 @@ void fused_lamb_cuda(at::Tensor& p,
lamb_coeff.data<scalar_t>());
}));
}
C10_CUDA_CHECK(cudaGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
...
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
#ifndef _WIN32
extern __device__ void error(void);
error();
#endif
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "type_shim_hip.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 && reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 && threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
C10_HIP_CHECK(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def(
"ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_asym_fp32",
&ds_sr_quantize_asym<float>,
"DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_asym_fp16",
&ds_sr_quantize_asym<__half>,
"DeepSpeed Quantize with fp16 (CUDA)");
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_hip_layers.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def(
"ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_asym_fp32",
&ds_sr_quantize_asym<float>,
"DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_asym_fp16",
&ds_sr_quantize_asym<__half>,
"DeepSpeed Quantize with fp16 (CUDA)");
}
#include <math.h>
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
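// Both quantize_kernel overloads (the __half one above and the float one
// below) implement the same symmetric fake quantization per group: compute
// the group's absolute maximum, set
//   q_scale = 2^num_bits / (2 * absmax + 1e-5),
// and write back round(x * q_scale) / q_scale in place, so values keep their
// original dtype but are restricted to 2^num_bits evenly spaced levels.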
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
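// Minimal host-side sketch of driving launch_quantize_kernel directly
// (assumptions: the data is already resident on the GPU, total_count is
// divisible by 4 * group_num, and the caller owns the stream); illustrative
// only, not part of this file's build.
#if 0
void example_launch(float* d_vals, int total_count, int group_num, cudaStream_t stream)
{
    launch_quantize_kernel<float>(d_vals, total_count, group_num, /*num_bits=*/8, stream);
    cudaStreamSynchronize(stream);  // wait for the in-place quantization to finish
}
#endif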
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
q_data_int[0].x =
(rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
: q_data_int[0].x;
q_data_int[0].y =
(rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
: q_data_int[0].y;
q_data_int[1].x =
(rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
: q_data_int[1].x;
q_data_int[1].y =
(rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x / q_scale_val;
data_f[0].y = q_data_int[0].y / q_scale_val;
data_f[1].x = q_data_int[1].x / q_scale_val;
data_f[1].y = q_data_int[1].y / q_scale_val;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
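// Stochastic-rounding summary for the sr_quantize_kernel overloads: each
// value is first truncated toward zero to an integer level, then bumped one
// level away from zero with probability equal to the scaled truncation error;
// the bump is suppressed at the extreme levels of the signed range
// [-2^(num_bits-1), 2^(num_bits-1) - 1], and the dequantized result is
// written back in place.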
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
data[reg_count] = vals_cast[group_index];
if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)(q_data.x * q_scale_val));
q_data_int.y = (float)((int)(q_data.y * q_scale_val));
q_data_int.w = (float)((int)(q_data.w * q_scale_val));
q_data_int.z = (float)((int)(q_data.z * q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
q_data_int.x =
(rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
: q_data_int.x;
q_data_int.y =
(rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
: q_data_int.y;
q_data_int.w =
(rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
: q_data_int.w;
q_data_int.z =
(rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
: q_data_int.z;
q_data_int.x /= q_scale_val;
q_data_int.y /= q_scale_val;
q_data_int.w /= q_scale_val;
q_data_int.z /= q_scale_val;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
float min = 10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
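// The quantize_kernel_asym overloads use an affine (asymmetric) scheme: the
// group's [min, max] range is divided into 2^num_bits steps,
//   q_scale = (max - min + 1e-5) / 2^num_bits,
// and each value is replaced in place by round((x - min) / q_scale) * q_scale + min.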
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
float min = 10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
float high_q = (float)((1 << num_bits) - 1);
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
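// Asymmetric stochastic rounding (sr_quantize_kernel_asym): levels span the
// unsigned range [0, 2^num_bits - 1] over the group's [min, max]. Values are
// truncated down to a level and bumped up by one with probability equal to
// the truncation error, except at the top level, before being dequantized
// back in place.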
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    sr_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
q_data_int[0].x =
(rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
: q_data_int[0].x;
q_data_int[0].y =
(rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
: q_data_int[0].y;
q_data_int[1].x =
(rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
: q_data_int[1].x;
q_data_int[1].y =
(rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x / q_scale_val;
data_f[0].y = q_data_int[0].y / q_scale_val;
data_f[1].x = q_data_int[1].x / q_scale_val;
data_f[1].y = q_data_int[1].y / q_scale_val;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
data[reg_count] = vals_cast[group_index];
if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)(q_data.x * q_scale_val));
q_data_int.y = (float)((int)(q_data.y * q_scale_val));
q_data_int.w = (float)((int)(q_data.w * q_scale_val));
q_data_int.z = (float)((int)(q_data.z * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
q_data_int.x =
(rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
: q_data_int.x;
q_data_int.y =
(rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
: q_data_int.y;
q_data_int.w =
(rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
: q_data_int.w;
q_data_int.z =
(rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
: q_data_int.z;
q_data_int.x /= q_scale_val;
q_data_int.y /= q_scale_val;
q_data_int.w /= q_scale_val;
q_data_int.z /= q_scale_val;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
float min = 10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
float min = 10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel_asym), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
float high_q = (float)((1 << num_bits) - 1);
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
    hipLaunchKernelGGL(( sr_quantize_kernel_asym), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);