Commit 67ea635f authored by aiss's avatar aiss
Browse files

push dsv0.8.2 version

parent 1b2721ad
Pipeline #201 failed with stages
in 0 seconds
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include "ds_kernel_utils.h"
#include <cuda_fp16.h>
#include <stdint.h>
#ifdef BF16_AVAILABLE
#include <cuda_bf16.h>
#endif
namespace conversion {

/*
Templated scalar conversion utilities.

Usage: conversion::to<TO>(from_value). Each supported (TO, FROM) pair is a
full specialization below; the generic primary template exists only so the
specializations have something to specialize. Note the primary template's
body is intentionally self-recursive: if it is ever instantiated for an
unsupported pair the recursion fails to resolve to a concrete conversion,
surfacing the missing specialization at compile/link time rather than
silently doing a C-style cast.
*/
template <typename TO, typename FROM>
DS_D_INLINE TO to(FROM val)
{
    return to(val);
}

// Specializations

/********************* Identity Conversions *********************/
/*
Identity conversions are useful in templated functions where we might have
a fixed destination type. For example, I might have a kernel that accepts
__half, __nv_bfloat16, and float but always want to do the core computation
at floating point:

T mem_value = input[idx];
float compute_value = conversion::to<float, T>(mem_value);

In practice, we should be able to elide the second template parameter:
float compute_val = conversion::to<float>(mem_value);

In this case, we need an implementation to handle the T = float case

NOTE: The type inferencing system appears to be unable to handle inferring the first
template parameter, even in the trivial case.
*/

// Floating point types
template <>
DS_D_INLINE double to(double val)
{
    return val;
}
template <>
DS_D_INLINE float to(float val)
{
    return val;
}
template <>
DS_D_INLINE __half to(__half val)
{
    return val;
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
{
    return val;
}
#endif

// Integer types
template <>
DS_D_INLINE int8_t to(int8_t val)
{
    return val;
}
template <>
DS_D_INLINE uint8_t to(uint8_t val)
{
    return val;
}
template <>
DS_D_INLINE int16_t to(int16_t val)
{
    return val;
}
template <>
DS_D_INLINE uint16_t to(uint16_t val)
{
    return val;
}
template <>
DS_D_INLINE int32_t to(int32_t val)
{
    return val;
}
template <>
DS_D_INLINE uint32_t to(uint32_t val)
{
    return val;
}
template <>
DS_D_INLINE int64_t to(int64_t val)
{
    return val;
}
template <>
DS_D_INLINE uint64_t to(uint64_t val)
{
    return val;
}

// TODO: evaluate if we want bools

/********************* To Double Conversions *********************/

// * to double variants

// Would normally like to not use C cast, but this is an important enough conversion
// to keep
template <>
DS_D_INLINE double to(float val)
{
#ifdef PTX_AVAILABLE
    double ret_val;
    // Fixed: opcode was misspelled "ctv.rn.f64.f32", which ptxas rejects.
    asm("cvt.rn.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
    return ret_val;
#else
    return double(val);
#endif
}

// Note: there is a CVT instruction for __half -> double, but there's no inline interface
// for passing a single half value
template <>
DS_D_INLINE double to(__half val)
{
    return to<double>(__half2float(val));
}
template <>
DS_D_INLINE double to(int64_t val)
{
    return __ll2double_rn(val);
}
template <>
DS_D_INLINE double to(int32_t val)
{
    return __int2double_rn(val);
}
template <>
DS_D_INLINE double to(int16_t val)
{
    // int16 widens losslessly to int32 before the intrinsic conversion.
    return __int2double_rn(val);
}
template <>
DS_D_INLINE double to(int8_t val)
{
    return __int2double_rn(val);
}
template <>
DS_D_INLINE double to(uint64_t val)
{
    return __ull2double_rn(val);
}
template <>
DS_D_INLINE double to(uint32_t val)
{
    return __uint2double_rn(val);
}
template <>
DS_D_INLINE double to(uint16_t val)
{
    return __uint2double_rn(val);
}
template <>
DS_D_INLINE double to(uint8_t val)
{
    return __uint2double_rn(val);
}

// Same applies here
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE double to(__nv_bfloat16 val)
{
    return to<double>(__bfloat162float(val));
}
#endif

/********************* To Float Conversions *********************/
template <>
DS_D_INLINE float to(double val)
{
    return __double2float_rn(val);
}
template <>
DS_D_INLINE float to(__half val)
{
    return __half2float(val);
}
template <>
DS_D_INLINE float to(int64_t val)
{
    return __ll2float_rn(val);
}
template <>
DS_D_INLINE float to(int32_t val)
{
    return __int2float_rn(val);
}
template <>
DS_D_INLINE float to(int16_t val)
{
    return __int2float_rn(val);
}
template <>
DS_D_INLINE float to(int8_t val)
{
    return __int2float_rn(val);
}
template <>
DS_D_INLINE float to(uint64_t val)
{
    return __ull2float_rn(val);
}
template <>
DS_D_INLINE float to(uint32_t val)
{
    return __uint2float_rn(val);
}
template <>
DS_D_INLINE float to(uint16_t val)
{
    return __uint2float_rn(val);
}
template <>
DS_D_INLINE float to(uint8_t val)
{
    return __uint2float_rn(val);
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE float to(__nv_bfloat16 val)
{
    return __bfloat162float(val);
}
#endif

/********************* To Float2 Conversions *********************/
template <>
DS_D_INLINE float2 to(__half2 val)
{
    return __half22float2(val);
}
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE float2 to(__nv_bfloat162 val)
{
    return __bfloat1622float2(val);
}
#endif

/********************* To Half Conversions *********************/
// NOTE(review): the double -> __half specialization below was disabled in this
// port ("//aiss"). __double2half is provided by recent CUDA toolkits; confirm
// toolkit/HIP support before re-enabling. While disabled, to<__half>(double)
// falls into the non-converting primary template.
//template <>
//DS_D_INLINE __half to(double val)
//{
//    return __double2half(val);
//}
template <>
DS_D_INLINE __half to(float val)
{
    return __float2half(val);
}
template <>
DS_D_INLINE __half to(int64_t val)
{
    return __ll2half_rn(val);
}
template <>
DS_D_INLINE __half to(int32_t val)
{
    return __int2half_rn(val);
}
template <>
DS_D_INLINE __half to(int16_t val)
{
    return __short2half_rn(val);
}
template <>
DS_D_INLINE __half to(int8_t val)
{
    return __int2half_rn(val);
}
template <>
DS_D_INLINE __half to(uint64_t val)
{
    return __ull2half_rn(val);
}
template <>
DS_D_INLINE __half to(uint32_t val)
{
    return __uint2half_rn(val);
}
template <>
DS_D_INLINE __half to(uint16_t val)
{
    return __ushort2half_rn(val);
}
template <>
DS_D_INLINE __half to(uint8_t val)
{
    return __uint2half_rn(val);
}
#ifdef BF16_AVAILABLE
// No direct conversion: route bf16 -> float -> half.
template <>
DS_D_INLINE __half to(__nv_bfloat16 val)
{
    return to<__half>(to<float>(val));
}
#endif

/********************* To Half2 Conversions *********************/
template <>
DS_D_INLINE __half2 to(float2 val)
{
    return __float22half2_rn(val);
}
#ifdef BF16_AVAILABLE
// No direct conversion: route bf162 -> float2 -> half2.
template <>
DS_D_INLINE __half2 to(__nv_bfloat162 val)
{
    return to<__half2>(to<float2>(val));
}
#endif

/********************* To BF16 Conversions *********************/
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat16 to(double val)
{
    return __double2bfloat16(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(float val)
{
    return __float2bfloat16(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(int64_t val)
{
    return __ll2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(int32_t val)
{
    return __int2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(int16_t val)
{
    return __short2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(int8_t val)
{
    return __int2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint64_t val)
{
    return __ull2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint32_t val)
{
    return __uint2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint16_t val)
{
    return __ushort2bfloat16_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat16 to(uint8_t val)
{
    return __uint2bfloat16_rn(val);
}
#endif

/********************* To BF162 Conversions *********************/
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE __nv_bfloat162 to(float2 val)
{
    return __float22bfloat162_rn(val);
}
template <>
DS_D_INLINE __nv_bfloat162 to(__half2 val)
{
    return to<__nv_bfloat162>(to<float2>(val));
}
#endif

/********************* To INT64_T Conversions *********************/
template <>
DS_D_INLINE int64_t to(double val)
{
    return __double2ll_rn(val);
}
template <>
DS_D_INLINE int64_t to(float val)
{
    return __float2ll_rn(val);
}
template <>
DS_D_INLINE int64_t to(__half val)
{
    return __half2ll_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int64_t to(__nv_bfloat16 val)
{
    return __bfloat162ll_rn(val);
}
#endif

/********************* To INT32_T Conversions *********************/
template <>
DS_D_INLINE int32_t to(double val)
{
    return __double2int_rn(val);
}
template <>
DS_D_INLINE int32_t to(float val)
{
    return __float2int_rn(val);
}
template <>
DS_D_INLINE int32_t to(__half val)
{
    return __half2int_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int32_t to(__nv_bfloat16 val)
{
    return __bfloat162int_rn(val);
}
#endif

/********************* To INT16_T Conversions *********************/
// These convert through the 32-bit intrinsics and then narrow implicitly;
// values outside the int16_t range are truncated, matching the original code.
template <>
DS_D_INLINE int16_t to(double val)
{
    return __double2int_rn(val);
}
template <>
DS_D_INLINE int16_t to(float val)
{
    return __float2int_rn(val);
}
template <>
DS_D_INLINE int16_t to(__half val)
{
    return __half2int_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int16_t to(__nv_bfloat16 val)
{
    return __bfloat162int_rn(val);
}
#endif

/********************* To INT8_T Conversions *********************/
// Same narrowing-through-int32 behavior as the int16_t variants above.
template <>
DS_D_INLINE int8_t to(double val)
{
    return __double2int_rn(val);
}
template <>
DS_D_INLINE int8_t to(float val)
{
    return __float2int_rn(val);
}
template <>
DS_D_INLINE int8_t to(__half val)
{
    return __half2int_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE int8_t to(__nv_bfloat16 val)
{
    return __bfloat162int_rn(val);
}
#endif

/********************* To UINT64_T Conversions *********************/
template <>
DS_D_INLINE uint64_t to(double val)
{
    return __double2ull_rn(val);
}
template <>
DS_D_INLINE uint64_t to(float val)
{
    return __float2ull_rn(val);
}
template <>
DS_D_INLINE uint64_t to(__half val)
{
    return __half2ull_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint64_t to(__nv_bfloat16 val)
{
    return __bfloat162ull_rn(val);
}
#endif

/********************* To UINT32_T Conversions *********************/
template <>
DS_D_INLINE uint32_t to(double val)
{
    return __double2uint_rn(val);
}
template <>
DS_D_INLINE uint32_t to(float val)
{
    return __float2uint_rn(val);
}
template <>
DS_D_INLINE uint32_t to(__half val)
{
    return __half2uint_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint32_t to(__nv_bfloat16 val)
{
    return __bfloat162uint_rn(val);
}
#endif

/********************* To UINT16_T Conversions *********************/
// Convert through uint32 and narrow implicitly (see int16_t note above).
template <>
DS_D_INLINE uint16_t to(double val)
{
    return __double2uint_rn(val);
}
template <>
DS_D_INLINE uint16_t to(float val)
{
    return __float2uint_rn(val);
}
template <>
DS_D_INLINE uint16_t to(__half val)
{
    return __half2uint_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint16_t to(__nv_bfloat16 val)
{
    return __bfloat162uint_rn(val);
}
#endif

/********************* To UINT8_T Conversions *********************/
// Convert through uint32 and narrow implicitly (see int16_t note above).
template <>
DS_D_INLINE uint8_t to(double val)
{
    return __double2uint_rn(val);
}
template <>
DS_D_INLINE uint8_t to(float val)
{
    return __float2uint_rn(val);
}
template <>
DS_D_INLINE uint8_t to(__half val)
{
    return __half2uint_rn(val);
}
// No direct support for integer casts at the C++ level and I don't feel they're so important
// to demand an PTX at this time
#ifdef BF16_AVAILABLE
template <>
DS_D_INLINE uint8_t to(__nv_bfloat16 val)
{
    return __bfloat162uint_rn(val);
}
#endif

}  // namespace conversion
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#define NOMINMAX // Windows idiosyncrasy
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <cassert>
#include "simd.h"
#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
#include "simd.h"
typedef __half ds_half_precision_t;
#else
typedef unsigned short ds_half_precision_t;
#endif
#define STEP(SPAN) \
void Step_##SPAN(float* _params, \
float* grads, \
float* _exp_avg_sq, \
size_t _param_size, \
__half* dev_param = nullptr, \
ds_half_precision_t* dev_param = nullptr, \
bool half_precision = false);
class Adagrad_Optimizer {
public:
Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0)
: _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false)
: _alpha(alpha), _eps(eps), _weight_decay(weight_decay)
{
#if defined(__ENABLE_CUDA__)
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_streams[0] = Context::Instance().GetCurrentStream();
_streams[1] = Context::Instance().GetNewStream();
_buf_index = false;
#endif
}
~Adagrad_Optimizer()
{
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
......@@ -42,16 +57,18 @@ public:
float* grads,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr,
ds_half_precision_t* dev_param = nullptr,
bool half_precision = false);
#endif
STEP(1)
STEP(4)
STEP(8)
#if defined(__ENABLE_CUDA__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#endif
inline void IncrementStep(size_t step)
{
_step++;
......@@ -73,10 +90,11 @@ private:
float _betta2_t;
size_t _step;
float* _doubled_buffer[2];
#if defined(__ENABLE_CUDA__)
bool _buf_index;
float* _doubled_buffer[2];
cudaStream_t _streams[2];
#endif
};
#if defined(__AVX512__) or defined(__AVX256__)
......@@ -86,7 +104,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
float* grads,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t new_rounded_size = 0;
......@@ -104,7 +122,9 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
size_t copy_size = TILE;
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
......@@ -128,12 +148,14 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
simd_store<span>(_params + i, param_4, half_precision);
#if defined(__ENABLE_CUDA__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
#endif
simd_store<span>(_exp_avg_sq + i, variance_4, false);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
launch_param_update_half(
......@@ -144,6 +166,7 @@ void Adagrad_Optimizer::Step_AVX(size_t* rounded_size,
_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
}
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#define NOMINMAX // Windows idiosyncrasy
// https://stackoverflow.com/questions/4913922/possible-problems-with-nominmax-on-visual-c
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>
#include <cassert>
#include "simd.h"
#if defined(__ENABLE_CUDA__)
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "cuda.h"
#include "custom_cuda_layers.h"
#include "simd.h"
typedef __half ds_half_precision_t;
#else
#include <cmath>
typedef unsigned short ds_half_precision_t;
#endif
#define STEP(SPAN) \
void Step_##SPAN(float* _params, \
......@@ -17,7 +28,7 @@
float* _exp_avg, \
float* _exp_avg_sq, \
size_t _param_size, \
__half* dev_param = nullptr, \
ds_half_precision_t* dev_param = nullptr, \
bool half_precision = false);
class Adam_Optimizer {
......@@ -36,20 +47,25 @@ public:
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_buf_index(false),
_adamw_mode(adamw_mode)
{
#if defined(__ENABLE_CUDA__)
cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float));
cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float));
_streams[0] = Context::Instance().GetCurrentStream();
_streams[1] = Context::Instance().GetNewStream();
_buf_index = false;
#endif
}
~Adam_Optimizer()
{
#if defined(__ENABLE_CUDA__)
cudaFreeHost(_doubled_buffer[0]);
cudaFreeHost(_doubled_buffer[1]);
#endif
}
#if defined(__AVX512__) or defined(__AVX256__)
template <int span>
void Step_AVX(size_t* rounded_size,
......@@ -58,16 +74,18 @@ public:
float* _exp_avg,
float* _exp_avg_sq,
size_t param_size,
__half* dev_param = nullptr,
ds_half_precision_t* dev_param = nullptr,
bool half_precision = false);
#endif
STEP(1)
STEP(4)
STEP(8)
#if defined(__ENABLE_CUDA__)
inline void SynchronizeStreams()
{
for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]);
}
#endif
inline void IncrementStep(size_t step, float beta1, float beta2)
{
if (beta1 != _betta1 || beta2 != _betta2) {
......@@ -116,11 +134,13 @@ private:
float _bias_correction1;
float _bias_correction2;
float* _doubled_buffer[2];
bool _buf_index;
bool _adamw_mode;
#if defined(__ENABLE_CUDA__)
float* _doubled_buffer[2];
cudaStream_t _streams[2];
bool _buf_index;
#endif
};
#if defined(__AVX512__) or defined(__AVX256__)
......@@ -131,10 +151,11 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params,
ds_half_precision_t* dev_params,
bool half_precision)
{
size_t new_rounded_size = 0;
int rshft = half_precision ? 1 : 0;
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
......@@ -167,11 +188,13 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
size_t copy_size = TILE;
if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t;
size_t offset = copy_size + t;
#if defined(__ENABLE_CUDA__)
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#endif
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH * span) {
AVX_Data grad_4[span];
simd_load<span>(grad_4, grads + i, half_precision);
simd_load<span>(grad_4, grads + (i >> rshft), half_precision);
AVX_Data momentum_4[span];
simd_load<span>(momentum_4, _exp_avg + i, false);
......@@ -180,7 +203,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_load<span>(variance_4, _exp_avg_sq + i, false);
AVX_Data param_4[span];
simd_load<span>(param_4, _params + i, half_precision);
simd_load<span>(param_4, _params + (i >> rshft), half_precision);
if (_weight_decay > 0 && !_adamw_mode) {
simd_fma<span>(grad_4, param_4, weight_decay4, grad_4);
......@@ -201,14 +224,16 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
simd_fma<span>(param_4, grad_4, step_size_4, param_4);
simd_store<span>(_params + i, param_4, half_precision);
simd_store<span>(_params + (i >> rshft), param_4, half_precision);
#if defined(__ENABLE_CUDA__)
if (dev_params) {
simd_store<span>(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision);
}
#endif
simd_store<span>(_exp_avg + i, momentum_4, false);
simd_store<span>(_exp_avg_sq + i, variance_4, false);
}
#if defined(__ENABLE_CUDA__)
if (dev_params) {
if (half_precision)
launch_param_update_half(
......@@ -219,6 +244,7 @@ void Adam_Optimizer::Step_AVX(size_t* rounded_size,
_buf_index = !_buf_index;
}
#endif
}
*rounded_size = new_rounded_size;
}
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <assert.h>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include "ds_kernel_utils.h"
#include <cuda.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>
#ifdef __HIP_PLATFORM_HCC__
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#else
#if __CUDA_ARCH__ >= 700
#define HALF_PRECISION_AVAILABLE = 1
#endif
#include <cooperative_groups.h>
#endif
#include <curand_kernel.h>
#include "context.h"
#include "cublas_wrappers.h"
......@@ -45,30 +41,6 @@
#define WARP_SIZE_BITS 5
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
......@@ -301,3 +273,54 @@ void launch_fuse_transpose_bias_kernel(const T* inp,
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream);
void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream);
void launch_token_sort(int32_t* indices,
int layers,
int batch_size,
int reserved_size,
int original_tokens,
cudaStream_t stream);
template <typename T>
void launch_gather_tokens(T* retained_tokens,
T* activations,
int32_t* gather_indices,
int32_t batch_size,
int32_t sampled_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride,
cudaStream_t stream);
template <typename T>
void launch_scatter_tokens(T* all_activations,
T* layer_activations,
int32_t* gather_indices,
int32_t batch_size,
int32_t sampled_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride,
cudaStream_t stream);
template <typename T>
void launch_slice_gpt_mask(T* output_mask,
const T* input_mask,
int batch_size,
int truncated_seq_len,
int orig_seq_len,
cudaStream_t stream);
template <typename T>
void launch_slice_bert_mask(T* output_mask,
const T* input_mask,
const int32_t* retained_indices,
int32_t layers,
int32_t batch_size,
int32_t truncated_seq_len,
int32_t orig_seq_len,
cudaStream_t stream);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
namespace cg = cooperative_groups;
#pragma once
namespace dequantize {
// Mirror the quantize namespace's vocabulary so dequant code reads symmetrically.
using Type = quantize::Type;
template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;
constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;

// Elements of __half (resp. __half2) covered by one `granularity`-byte chunk.
constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);

/*
Device function that reads quantized data from global memory, dequantizes
it, and stores it to global memory.
Template Arguments :
    numBits - Number of bits in quantized element. int: 4, 8
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    unroll - Number of load steps to internally unroll int
    threads - Number of threads to perform dequant int
Function arguments:
    global_output - __half pointer in global memory
    data - Quantized data in global memory
    global_params - Quantization parameters in global memory
    elems_per_group - Number of elements in each quantization group
    total_elems - Tensor size (note, does not need to be multiple of elems_per_group)
NOTE(review): this __half-specific overload has no matching definition in this
view; the generic `template <typename T, ...>` to_global below appears to
supersede it — confirm no caller instantiates this form.
*/
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems);

/*
Device function that quantizes 16 bytes of __half type input data.
Template Arguments :
    numBits - Number of bits in quantized element. int : 8 or 4
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
    local_output - Local array to store dequantized data __half* or __half2*
    data - Pointer to quantized input data. int8_t*
    Params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);

template <typename T, int numBits, Type qType>
DS_D_INLINE void chunk(T* local_output, const int8_t* data, Params<qType, numBits> q_params);

/**************** Implementations ******************/

// Dequantize one chunk: for 8-bit data each input byte yields one output
// element; for 4-bit data each byte is unpacked into a low/high nibble pair.
template <typename T, int numBits, Type qType>
DS_D_INLINE void chunk(T* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    // 8-bit -> 1 element per byte; 4-bit -> 2 elements per byte.
    constexpr int32_t num_elems_packed = 8 / numBits;
    constexpr int32_t iters = h_per_chunk / num_elems_packed;
#pragma unroll
    for (int i = 0; i < iters; i++) {
        if constexpr (num_elems_packed == 1) {
            local_output[i] = q_params.template dequantize<T>(data[i]);
        } else {
            // 4-bit path: reinterpret the byte as a packed nibble pair and
            // emit both halves to adjacent output slots.
            auto accessible_data = *(PackedInt4*)(&data[i]);
            local_output[2 * i] = q_params.template dequantize<T>(accessible_data.low);
            local_output[2 * i + 1] = q_params.template dequantize<T>(accessible_data.high);
        }
    }
}

// __half2 convenience overload: view the vector buffer as scalar __half and
// reuse the scalar chunk implementation (qType is deduced from q_params).
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
    __half* local_output_cast = reinterpret_cast<__half*>(local_output);
    chunk<__half, numBits>(local_output_cast, data, q_params);
}

// Core global->global dequant loop. Each thread loads `unroll` strided
// `load_granularity`-byte packets of quantized bytes, dequantizes them into a
// register-resident buffer, and stores `granularity`-byte packets of T.
template <typename T, int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(T* global_output,
                            const int8_t* data,
                            const float* global_params,
                            const int elems_per_group,
                            const int total_elems)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);

    // Load constants
    // TODO(cmikeh2): Refactor into functions?
    // Quantized input is 1 byte per element for 8-bit, half that for 4-bit.
    constexpr int load_granularity = (granularity / (sizeof(T))) / (numBits == 8 ? 1 : 2);
    constexpr int load_step_stride = load_granularity * threads;
    constexpr int load_block_stride = load_step_stride * unroll;

    // Store constants
    constexpr int T_per_chunk = granularity / sizeof(T);
    constexpr int store_step_stride = T_per_chunk * threads;
    constexpr int store_block_stride = store_step_stride * unroll;

    // Load offsets
    const int load_block_offset = tb.group_index().x * load_block_stride;
    // Note: we can use `load_granularity` since the dtype is `int8_t`.
    const int load_thread_offset = tb.thread_index().x * load_granularity;
    const int8_t* load_base = data + load_block_offset + load_thread_offset;

    // Store offsets
    const int store_block_offset = tb.group_index().x * store_block_stride;
    const int store_thread_offset = tb.thread_index().x * T_per_chunk;
    const int elem_id_base = store_block_offset + store_thread_offset;

    int8_t local_load_buffer[load_granularity * unroll];
    T local_dequant_buffer[T_per_chunk * unroll];

    /*
    Note: Splitting this loop in half gave about 3-5% performance increase for reasons that aren't
    totally clear to me, so this is a deliberately weird code structure.
    */
    // Phase 1: issue all guarded global loads.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
                                                      load_base + i * load_step_stride);
        }
    }

    // Phase 2: dequantize each loaded packet and store it back out.
#pragma unroll
    for (int i = 0; i < unroll; i++) {
        const int elem_id_iter = elem_id_base + i * store_step_stride;
        if (elem_id_iter < total_elems) {
            // TODO(cmikeh2): Can we amortize this division? Perform once on the first iteration and
            // use indexing math to do division free interpolation of the successive groups?
            const int group_index = elem_id_iter / elems_per_group;
            Params<qType, numBits> q_params(global_params, group_index);
            chunk<T, numBits, qType>(local_dequant_buffer + i * T_per_chunk,
                                     local_load_buffer + i * load_granularity,
                                     q_params);
            mem_access::store_global<granularity>(global_output + elem_id_iter,
                                                  local_dequant_buffer + i * T_per_chunk);
        }
    }
}

// Public entry point: dispatch on the supported bit widths at compile time.
// 3-bit is reserved (TODO) and currently traps, as does any other width.
template <typename T, int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(T* global_output,
                           const int8_t* data,
                           const float* global_params,
                           const int elems_per_group,
                           const int total_elems)
{
    if constexpr (numBits == 4 || numBits == 8) {
        _to_global<T, numBits, qType, unroll, threads>(
            global_output, data, global_params, elems_per_group, total_elems);
    } else if constexpr (numBits == 3) {
        // TODO(cmikeh2): Need this implementation
        assert(false);
    } else {
        assert(false);
    }
}

}  // namespace dequantize
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
Centralized header file for preprocessor macros and constants
used throughout the codebase.
*/
#pragma once
#include <cuda.h>
#define DS_HD_INLINE __host__ __device__ __forceinline__
#define DS_D_INLINE __device__ __forceinline__
#ifdef __HIP_PLATFORM_HCC__
// constexpr variant of warpSize for templating
constexpr int hw_warp_size = 64;
#define HALF_PRECISION_AVAILABLE = 1
#include <hip/hip_cooperative_groups.h>
#else // !__HIP_PLATFORM_HCC__
// constexpr variant of warpSize for templating
constexpr int hw_warp_size = 32;
#if __CUDA_ARCH__ >= 530
#define HALF_PRECISION_AVAILABLE = 1
#define PTX_AVAILABLE
#endif // __CUDA_ARCH__ >= 530
#if __CUDA_ARCH__ >= 800
#define ASYNC_COPY_AVAILABLE
#define BF16_AVAILABLE
#endif // __CUDA_ARCH__ >= 800
#include <cooperative_groups.h>
#endif //__HIP_PLATFORM_HCC__
// Round `val` up to the next power of two (returns `val` unchanged if it is
// already a power of two). Standard bit-smearing technique: propagate the
// most-significant set bit of (val - 1) into every lower position, then add 1.
// Fixed: the original stopped at `>> 8`, which only smears 16 bits, so inputs
// above 65536 produced a non-power-of-two (e.g. 70000 -> 131071). The `>> 16`
// step covers the full 32-bit int range.
// Preconditions: val >= 1 and val <= 2^30 (larger values would overflow int).
inline int next_pow2(const int val)
{
    int rounded_val = val - 1;
    rounded_val |= rounded_val >> 1;
    rounded_val |= rounded_val >> 2;
    rounded_val |= rounded_val >> 4;
    rounded_val |= rounded_val >> 8;
    rounded_val |= rounded_val >> 16;
    return rounded_val + 1;
}
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda_runtime_api.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#ifndef __FEEDFORWARD_H__
#define __FEEDFORWARD_H__
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
#include "ds_kernel_utils.h"
/////////////////////////////// Memory Access Utils ///////////////////////////////
namespace mem_access {
enum class LoadPolicy {
CacheAll, // Cache at all levels
CacheGlobal, // Cache at L2 only
CacheStreaming // Cache with evict first policy
};
enum class StorePolicy {
Writeback, // Cache in L1, write-back on eviction
CacheGlobal, // Bypass L1, write-back on eviction
CacheStreaming // Allocate cache line with evict first policy
};
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src);
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access);
// Shared accesses have no cache policy
template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src);
template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src, bool do_access);
template <int AccessSize, StorePolicy policy = StorePolicy::Writeback>
__device__ __forceinline__ void store_global(void* dst, const void* src);
// Shared accesses have no cache policy
template <int AccessSize>
__device__ __forceinline__ void store_shared(void* dst, const void* src);
#ifdef ASYNC_COPY_AVAILABLE
template <int AccessSize>
__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl);
template <int AccessSize>
__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate);
template <int AccessSize>
__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate);
__device__ __forceinline__ void memcpy_async_fence();
template <int stages>
__device__ __forceinline__ void memcpy_async_wait();
template <int stages>
__device__ __forceinline__ void tail_complete_wait(int remaining_stages);
#endif
// Util for tracking pipeline buffers
// TODO: Evaluate whether this should also be guarded by ASYNC_COPY_AVAILABLE
// Tracks which slot of a ring of `max` pipeline buffers is current.
// Each call to get() hands out the current slot index and advances the
// ring position, wrapping back to slot 0 after slot `max - 1`.
template <int max>
class BufferTracker {
public:
    int current_state;  // index of the next buffer to hand out, in [0, max)

    __device__ __forceinline__ BufferTracker() : current_state(0) {}

    // Returns the current buffer index and advances (with wraparound).
    __device__ __forceinline__ int get()
    {
        const int handed_out = current_state;
        current_state = (current_state + 1) % max;
        return handed_out;
    }
};
// Returns the calling thread's lane index within its warp, in [0, warpSize).
__device__ __forceinline__ uint32_t lane_id()
{
#ifdef PTX_AVAILABLE
    // Read the hardware lane id directly from the %laneid special register.
    unsigned int lane_id;
    asm volatile("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
#else
    // Assumes a 1D thread block (or that threadIdx.x alone determines the lane).
    return threadIdx.x & (warpSize - 1);  // Portable
#endif
}
/////////// Load Global ///////////
// 16-byte global load, default policy (ld.global.ca: cache at all levels).
template <>
__device__ __forceinline__ void load_global<16>(void* dst, const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.ca.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 16-byte global load: loads when do_access is true, else zero-fills dst.
// NOTE(review): the guarded load omits the .ca suffix; .ca is the documented default
// cache operator for ld.global, so behavior should match the unpredicated variant.
template <>
__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "l"(src), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}

// 16-byte global load cached at L2 only (ld.global.cg).
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 16-byte L2-only load: loads when do_access is true, else zero-fills dst.
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst,
                                                                         const void* src,
                                                                         bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "l"(src), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}
// 16-byte global load with streaming (evict-first) cache hint (ld.global.cs).
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "l"(src));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}
// Predicated 16-byte streaming global load: loads when do_access is true,
// else zero-fills dst without issuing a memory access.
// Fix: the guarded load previously used ld.global.cg (L2-only), silently
// dropping the streaming (.cs, evict-first) hint this specialization exists
// to provide. All other predicated variants match their unpredicated sibling.
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
                                                                            const void* src,
                                                                            bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.global.cs.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "l"(src), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}
// 8-byte global load, default policy (ld.global.ca).
template <>
__device__ __forceinline__ void load_global<8>(void* dst, const void* src)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.ca.v2.u32 {%0, %1}, [%2];\n"
                 : "=r"(data[0].x), "=r"(data[0].y)
                 : "l"(src));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 8-byte global load: loads when do_access is true, else zero-fills dst.
template <>
__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %3, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\t@p ld.global.v2.u32 {%0, %1}, [%2];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y)
        : "l"(src), "r"((int)do_access));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
    }
#endif
}

// 8-byte global load cached at L2 only (ld.global.cg).
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
                 : "=r"(data[0].x), "=r"(data[0].y)
                 : "l"(src));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 8-byte L2-only load.
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst,
                                                                        const void* src,
                                                                        bool do_access)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %3, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y)
        : "l"(src), "r"((int)do_access));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
    }
#endif
}

// 8-byte global load with streaming (evict-first) cache hint (ld.global.cs).
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
                 : "=r"(data[0].x), "=r"(data[0].y)
                 : "l"(src));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 8-byte streaming load.
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src,
                                                                           bool do_access)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %3, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y)
        : "l"(src), "r"((int)do_access));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
    }
#endif
}
// 4-byte global load, default policy (ld.global.ca).
template <>
__device__ __forceinline__ void load_global<4>(void* dst, const void* src)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.ca.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 4-byte global load: loads when do_access is true, else zeros dst.
template <>
__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\t@p ld.global.u32 {%0}, [%1];\n"
        "}\n"
        : "=r"(data[0])
        : "l"(src), "r"((int)do_access));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}

// 4-byte global load cached at L2 only (ld.global.cg).
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cg.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 4-byte L2-only load.
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst,
                                                                        const void* src,
                                                                        bool do_access)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\t@p ld.global.cg.u32 {%0}, [%1];\n"
        "}\n"
        : "=r"(data[0])
        : "l"(src), "r"((int)do_access));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}

// 4-byte global load with streaming (evict-first) cache hint (ld.global.cs).
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cs.u32 {%0}, [%1];\n" : "=r"(*data) : "l"(src));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 4-byte streaming load.
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src,
                                                                           bool do_access)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\t@p ld.global.cs.u32 {%0}, [%1];\n"
        "}\n"
        : "=r"(data[0])
        : "l"(src), "r"((int)do_access));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}
// 2-byte global load, default policy (ld.global.ca, "=h" binds a 16-bit register).
template <>
__device__ __forceinline__ void load_global<2>(void* dst, const void* src)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.ca.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 2-byte global load: loads when do_access is true, else zeros dst.
template <>
__device__ __forceinline__ void load_global<2>(void* dst, const void* src, bool do_access)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.u16 %0, 0;\n"
        "\t@p ld.global.u16 {%0}, [%1];\n"
        "}\n"
        : "=h"(*data)
        : "l"(src), "r"((int)do_access));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}

// 2-byte global load cached at L2 only (ld.global.cg).
template <>
__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cg.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 2-byte L2-only load.
template <>
__device__ __forceinline__ void load_global<2, LoadPolicy::CacheGlobal>(void* dst,
                                                                        const void* src,
                                                                        bool do_access)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.u16 %0, 0;\n"
        "\t@p ld.global.cg.u16 {%0}, [%1];\n"
        "}\n"
        : "=h"(*data)
        : "l"(src), "r"((int)do_access));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}

// 2-byte global load with streaming (evict-first) cache hint (ld.global.cs).
template <>
__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile("ld.global.cs.u16 {%0}, [%1];\n" : "=h"(*data) : "l"(src));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 2-byte streaming load.
template <>
__device__ __forceinline__ void load_global<2, LoadPolicy::CacheStreaming>(void* dst,
                                                                           const void* src,
                                                                           bool do_access)
{
    int16_t* data = reinterpret_cast<int16_t*>(dst);
#ifdef PTX_AVAILABLE
    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.u16 %0, 0;\n"
        "\t@p ld.global.cs.u16 {%0}, [%1];\n"
        "}\n"
        : "=h"(*data)
        : "l"(src), "r"((int)do_access));
#else
    const int16_t* src_cast = reinterpret_cast<const int16_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}
/////////// Load Shared ///////////
namespace internal {

#ifdef PTX_AVAILABLE
// Converts a generic-address-space pointer to a 32-bit shared-memory window
// address, as required by ld.shared/st.shared/cp.async operands.
__device__ __forceinline__ unsigned convert_to_shared(const void* ptr)
{
#if __CUDACC_VER_MAJOR__ >= 11
    // In CUDA 11 we have a builtin intrinsic
    return __cvta_generic_to_shared(ptr);
#else
    unsigned ret_val;
    // Fix: the cvta.to.shared.u64 statement was missing its terminating ';',
    // which is required by PTX grammar and fails assembly on this path.
    asm volatile(
        "{\n"
        "\t.reg .u64 p1;\n"
        "\tcvta.to.shared.u64 p1, %1;\n"
        "\tcvt.u32.u64 %0, p1;\n"
        "}\n"
        : "=r"(ret_val)
        : "l"(ptr));
    return ret_val;
#endif
}
#endif

}  // namespace internal
// 16-byte shared-memory load. `src` must point into shared memory.
template <>
__device__ __forceinline__ void load_shared<16>(void* dst, const void* src)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
                 : "r"(src_shr));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 16-byte shared load: loads when do_access is true, else zero-fills dst.
template <>
__device__ __forceinline__ void load_shared<16>(void* dst, const void* src, bool do_access)
{
    uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %5, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\tmov.b32 %2, 0;\n"
        "\tmov.b32 %3, 0;\n"
        "\t@p ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
        : "r"(src_shr), "r"((int)do_access));
#else
    const uint4* src_cast = reinterpret_cast<const uint4*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
        data[0].z = 0;
        data[0].w = 0;
    }
#endif
}

// 8-byte shared-memory load.
template <>
__device__ __forceinline__ void load_shared<8>(void* dst, const void* src)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile("ld.shared.v2.u32 {%0, %1}, [%2];\n"
                 : "=r"(data[0].x), "=r"(data[0].y)
                 : "r"(src_shr));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 8-byte shared load.
template <>
__device__ __forceinline__ void load_shared<8>(void* dst, const void* src, bool do_access)
{
    uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %3, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\tmov.b32 %1, 0;\n"
        "\t@p ld.shared.v2.u32 {%0, %1}, [%2];\n"
        "}\n"
        : "=r"(data[0].x), "=r"(data[0].y)
        : "r"(src_shr), "r"((int)do_access));
#else
    const uint2* src_cast = reinterpret_cast<const uint2*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0].x = 0;
        data[0].y = 0;
    }
#endif
}

// 4-byte shared-memory load.
template <>
__device__ __forceinline__ void load_shared<4>(void* dst, const void* src)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile("ld.shared.u32 {%0}, [%1];\n" : "=r"(*data) : "r"(src_shr));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    data[0] = src_cast[0];
#endif
}

// Predicated 4-byte shared load.
template <>
__device__ __forceinline__ void load_shared<4>(void* dst, const void* src, bool do_access)
{
    int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
    unsigned src_shr = internal::convert_to_shared(src);

    asm volatile(
        "{\n"
        "\t.reg .pred p;\n"
        "\tsetp.ne.b32 p, %2, 0;\n"
        "\tmov.b32 %0, 0;\n"
        "\t@p ld.shared.u32 %0, [%1];\n"
        "}\n"
        : "=r"(data[0])
        : "r"(src_shr), "r"((int)do_access));
#else
    const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
    if (do_access) {
        data[0] = src_cast[0];
    } else {
        data[0] = 0;
    }
#endif
}
/////////// Store Global ///////////
// 16-byte global store, default policy (st.global.wb: write-back).
// NOTE(review): only the 16-byte variants declare a "memory" clobber;
// the 8/4-byte variants below do not — confirm whether that asymmetry
// is intentional before relying on compiler ordering around them.
template <>
__device__ __forceinline__ void store_global<16>(void* dst, const void* src)
{
    const uint4* data = reinterpret_cast<const uint4*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.wb.v4.u32 [%0], {%1, %2, %3, %4};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
                 : "memory");
#else
    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 16-byte global store bypassing L1 (st.global.cg).
template <>
__device__ __forceinline__ void store_global<16, StorePolicy::CacheGlobal>(void* dst,
                                                                           const void* src)
{
    const uint4* data = reinterpret_cast<const uint4*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cg.v4.u32 [%0], {%1, %2, %3, %4};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
                 : "memory");
#else
    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 16-byte global store with streaming (evict-first) hint (st.global.cs).
template <>
__device__ __forceinline__ void store_global<16, StorePolicy::CacheStreaming>(void* dst,
                                                                              const void* src)
{
    const uint4* data = reinterpret_cast<const uint4*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w)
                 : "memory");
#else
    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 8-byte global store, default policy (st.global.wb).
template <>
__device__ __forceinline__ void store_global<8>(void* dst, const void* src)
{
    const uint2* data = reinterpret_cast<const uint2*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.wb.v2.u32 [%0], {%1, %2};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
#else
    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 8-byte global store bypassing L1 (st.global.cg).
template <>
__device__ __forceinline__ void store_global<8, StorePolicy::CacheGlobal>(void* dst,
                                                                          const void* src)
{
    const uint2* data = reinterpret_cast<const uint2*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cg.v2.u32 [%0], {%1, %2};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
#else
    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 8-byte global store with streaming hint (st.global.cs).
template <>
__device__ __forceinline__ void store_global<8, StorePolicy::CacheStreaming>(void* dst,
                                                                             const void* src)
{
    const uint2* data = reinterpret_cast<const uint2*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cs.v2.u32 [%0], {%1, %2};\n"
                 :
                 : "l"(dst), "r"(data[0].x), "r"(data[0].y));
#else
    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 4-byte global store, default policy (st.global.wb).
template <>
__device__ __forceinline__ void store_global<4>(void* dst, const void* src)
{
    const int32_t* data = reinterpret_cast<const int32_t*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.wb.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
#else
    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 4-byte global store bypassing L1 (st.global.cg).
template <>
__device__ __forceinline__ void store_global<4, StorePolicy::CacheGlobal>(void* dst,
                                                                          const void* src)
{
    const int32_t* data = reinterpret_cast<const int32_t*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cg.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
#else
    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 4-byte global store with streaming hint (st.global.cs).
template <>
__device__ __forceinline__ void store_global<4, StorePolicy::CacheStreaming>(void* dst,
                                                                             const void* src)
{
    const int32_t* data = reinterpret_cast<const int32_t*>(src);
#ifdef PTX_AVAILABLE
    asm volatile("st.global.cs.u32 [%0], %1;\n" : : "l"(dst), "r"(*data));
#else
    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
    dst_cast[0] = data[0];
#endif
}
/////////// Store Shared ///////////
// 16-byte shared-memory store. `dst` must point into shared memory.
template <>
__device__ __forceinline__ void store_shared<16>(void* dst, const void* src)
{
    const uint4* data = reinterpret_cast<const uint4*>(src);
#ifdef PTX_AVAILABLE
    unsigned dst_int = internal::convert_to_shared(dst);

    asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n"
                 :
                 : "r"(dst_int), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z), "r"(data[0].w));
#else
    uint4* dst_cast = reinterpret_cast<uint4*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 8-byte shared-memory store.
template <>
__device__ __forceinline__ void store_shared<8>(void* dst, const void* src)
{
    const uint2* data = reinterpret_cast<const uint2*>(src);
#ifdef PTX_AVAILABLE
    unsigned dst_int = internal::convert_to_shared(dst);

    asm volatile("st.shared.v2.u32 [%0], {%1, %2};\n"
                 :
                 : "r"(dst_int), "r"(data[0].x), "r"(data[0].y));
#else
    uint2* dst_cast = reinterpret_cast<uint2*>(dst);
    dst_cast[0] = data[0];
#endif
}

// 4-byte shared-memory store.
template <>
__device__ __forceinline__ void store_shared<4>(void* dst, const void* src)
{
    const int32_t* data = reinterpret_cast<const int32_t*>(src);
#ifdef PTX_AVAILABLE
    unsigned dst_int = internal::convert_to_shared(dst);

    asm volatile("st.shared.u32 [%0], %1;\n" : : "r"(dst_int), "r"(*data));
#else
    int32_t* dst_cast = reinterpret_cast<int32_t*>(dst);
    dst_cast[0] = data[0];
#endif
}
/////////// Asynchronous Memory Copy ///////////
#ifdef ASYNC_COPY_AVAILABLE
// Asynchronous global->shared copy of AccessSize bytes (cp.async.ca, SM80+).
// Must be followed by memcpy_async_fence()/memcpy_async_wait() before the
// shared data is read.
template <int AccessSize>
__device__ __forceinline__ void memcpy_async(void* shr, const void* gbl)
{
    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
    unsigned shr_int = internal::convert_to_shared(shr);

    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n"
                 :
                 : "r"(shr_int), "l"(gbl), "n"(AccessSize));
}

// Predicated async copy: issues the cp.async only when `predicate` is true;
// otherwise the shared destination is left untouched (a true no-op).
template <int AccessSize>
__device__ __forceinline__ void memcpy_async_nop(void* shr, const void* gbl, bool predicate)
{
    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
    unsigned shr_int = internal::convert_to_shared(shr);

    asm volatile(
        "{\n"
        "   .reg .pred p;\n"
        "   setp.ne.b32 p, %0, 0;\n"
        "   @p cp.async.ca.shared.global [%1], [%2], %3;\n"
        "}\n"
        :
        : "r"((int)predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize));
}

// Zero-fill variant: cp.async's optional src-size operand copies only
// `bytes_to_copy` bytes from global and zero-fills the remainder of the
// AccessSize destination — so a false predicate writes zeros to shared.
template <int AccessSize>
__device__ __forceinline__ void memcpy_async_zero(void* shr, const void* gbl, bool predicate)
{
    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
    unsigned shr_int = internal::convert_to_shared(shr);
    int bytes_to_copy = (predicate ? AccessSize : 0);

    asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n"
                 :
                 : "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
}

// Combined variant: `nop_predicate` gates whether any copy is issued at all,
// while `zero_predicate` selects between a real copy and a zero-fill.
template <int AccessSize>
__device__ __forceinline__ void memcpy_async_zero_nop(void* shr,
                                                      const void* gbl,
                                                      bool zero_predicate,
                                                      bool nop_predicate)
{
    static_assert((AccessSize == 4 || AccessSize == 8 || AccessSize == 16));
    unsigned shr_int = internal::convert_to_shared(shr);
    int bytes_to_copy = (zero_predicate ? AccessSize : 0);

    asm volatile(
        "{\n"
        "   .reg .pred p;\n"
        "   setp.ne.b32 p, %0, 0;\n"
        "   @p cp.async.ca.shared.global [%1], [%2], %3, %4;\n"
        "}\n"
        :
        : "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "n"(AccessSize), "r"(bytes_to_copy));
}
// Cache global variants. Separate interface to require deliberate use of them.
__device__ __forceinline__ void memcpy_async_cg(void* shr, const void* gbl)
{
unsigned shr_int = internal::convert_to_shared(shr);
asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" : : "r"(shr_int), "l"(gbl));
}
__device__ __forceinline__ void memcpy_async_nop_cg(void* shr, const void* gbl, bool predicate)
{
unsigned shr_int = internal::convert_to_shared(shr);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], 16;\n"
"}\n"
:
: "r"((int)predicate), "r"(shr_int), "l"(gbl));
}
__device__ __forceinline__ void memcpy_async_zero_cg(void* shr, const void* gbl, bool predicate)
{
unsigned shr_int = internal::convert_to_shared(shr);
int bytes_to_copy = (predicate ? 16 : 0);
asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n"
:
: "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
}
__device__ __forceinline__ void memcpy_async_zero_nop_cg(void* shr,
const void* gbl,
bool zero_predicate,
bool nop_predicate)
{
unsigned shr_int = internal::convert_to_shared(shr);
int bytes_to_copy = (zero_predicate ? 16 : 0);
asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %0, 0;\n"
" @p cp.async.cg.shared.global [%1], [%2], 16, %3;\n"
"}\n"
:
: "r"((int)nop_predicate), "r"(shr_int), "l"(gbl), "r"(bytes_to_copy));
}
// Commits all previously-issued cp.async operations of this thread into a group.
__device__ __forceinline__ void memcpy_async_fence() { asm volatile("cp.async.commit_group;\n"); }

// Blocks until at most `stages` committed cp.async groups remain outstanding
// (wait_group takes the number of groups allowed to still be in flight).
template <int stages>
__device__ __forceinline__ void memcpy_async_wait()
{
    static_assert(stages <= 8);

    asm volatile("cp.async.wait_group %0;\n" : : "n"(stages));
}
// TODO: The tail complete should be a known compile time artifact, should try and induce this
// without all of the branches from the call-site. This is a hacky solution.
// Drains the async-copy pipeline tail: maps the runtime `remaining_stages`
// count onto the matching compile-time memcpy_async_wait<k>() instantiation.
// A value of N or greater (no tail to drain) is a no-op.
template <>
__device__ __forceinline__ void tail_complete_wait<1>(int remaining_stages)
{
    switch (remaining_stages) {
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}

template <>
__device__ __forceinline__ void tail_complete_wait<2>(int remaining_stages)
{
    switch (remaining_stages) {
        case 1: memcpy_async_wait<1>(); break;
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}

template <>
__device__ __forceinline__ void tail_complete_wait<3>(int remaining_stages)
{
    switch (remaining_stages) {
        case 2: memcpy_async_wait<2>(); break;
        case 1: memcpy_async_wait<1>(); break;
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}

template <>
__device__ __forceinline__ void tail_complete_wait<4>(int remaining_stages)
{
    switch (remaining_stages) {
        case 3: memcpy_async_wait<3>(); break;
        case 2: memcpy_async_wait<2>(); break;
        case 1: memcpy_async_wait<1>(); break;
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}

template <>
__device__ __forceinline__ void tail_complete_wait<5>(int remaining_stages)
{
    switch (remaining_stages) {
        case 4: memcpy_async_wait<4>(); break;
        case 3: memcpy_async_wait<3>(); break;
        case 2: memcpy_async_wait<2>(); break;
        case 1: memcpy_async_wait<1>(); break;
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}

template <>
__device__ __forceinline__ void tail_complete_wait<6>(int remaining_stages)
{
    switch (remaining_stages) {
        case 5: memcpy_async_wait<5>(); break;
        case 4: memcpy_async_wait<4>(); break;
        case 3: memcpy_async_wait<3>(); break;
        case 2: memcpy_async_wait<2>(); break;
        case 1: memcpy_async_wait<1>(); break;
        case 0: memcpy_async_wait<0>(); break;
        default: break;
    }
}
#endif
} // namespace mem_access
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda_fp16.h>
#include "ds_kernel_utils.h"
namespace quantize {

// Supported quantization schemes.
enum class Type { Symmetric, Asymmetric };

// Two 4-bit quantized values packed into one byte (bitfield layout is
// compiler-defined; both producer and consumer use this same struct).
struct PackedInt4 {
    int8_t high : 4;
    int8_t low : 4;
};

// Asymmetric quantization stores an offset alongside the scale.
DS_HD_INLINE bool requires_offset(Type qType) { return qType == Type::Asymmetric; }

}  // namespace quantize
// Host-side launcher: quantizes `groups * elems_per_group` __half values into
// int8 output, writing per-group quantization parameters into `params`.
void launch_quant(int8_t* output_data,
                  float* params,
                  const __half* input_data,
                  const int groups,
                  const int elems_per_group,
                  const int num_bits,
                  const quantize::Type quant_type,
                  cudaStream_t stream);

// Host-side launcher: dequantizes int8 data back to T using per-group params.
template <typename T>
void launch_dequantize_kernel(T* dequant_data,
                              const int8_t* q_data,
                              const float* q_params,
                              quantize::Type q_type,
                              int num_bits,
                              int elems_per_group,
                              int total_elems,
                              cudaStream_t stream);

// In-place fake (simulated) quantization, symmetric, round-to-nearest.
template <typename T>
void launch_fake_quantize_kernel(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
                                 cudaStream_t stream);

// In-place fake quantization, symmetric, stochastic rounding.
template <typename T>
void launch_sr_fake_quantize_kernel(T* vals,
                                    int total_count,
                                    int group_num,
                                    int num_bits,
                                    cudaStream_t stream);

// In-place fake quantization, asymmetric, round-to-nearest.
template <typename T>
void launch_fake_quantize_kernel_asym(T* vals,
                                      int total_count,
                                      int group_num,
                                      int num_bits,
                                      cudaStream_t stream);

// In-place fake quantization, asymmetric, stochastic rounding.
template <typename T>
void launch_sr_fake_quantize_kernel_asym(T* vals,
                                         int total_count,
                                         int group_num,
                                         int num_bits,
                                         cudaStream_t stream);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "reduction_utils.h"
#pragma once
using rop = reduce::ROpType;

namespace quantize {

// All kernels in this namespace move data in 16-byte (128-bit) granules.
constexpr int granularity = 16;
// Number of __half elements per 16-byte access.
constexpr int h_per_load = granularity / sizeof(__half);
// Number of __half2 elements per 16-byte access.
constexpr int h2_per_load = granularity / sizeof(__half2);
// Hardware maximum threads per block.
constexpr int max_threads = 1024;
/*
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/
// Primary template: declared only; concrete behavior lives in the
// Type::Symmetric / Type::Asymmetric specializations below.
template <Type qType, int numBits>
class Params {
public:
    /*
    Quantization implementation, supports
    1) 4 Bit
    2) 8 Bit
    3) Symmetric
    4) Asymmetric

    Function Arguments :
        val : The __half value to quantize.
    */
    DS_D_INLINE int8_t quantize(__half val);

    // Inverse of quantize: maps an int8 code back to a value of type T.
    template <typename T>
    DS_D_INLINE T dequantize(int8_t val);

    // Writes this group's parameters to global memory at `group_index`.
    DS_D_INLINE void store(float* params, int group_index);

    // Initialize from memory
    DS_D_INLINE Params(const float* params, int group_index);
};
// Symmetric quantization parameters: a single scale mapping the range
// [-max, max] onto 2^numBits signed integer levels.
template <int numBits>
class Params<Type::Symmetric, numBits> {
public:
    // Multiplier applied before rounding when quantizing. When constructed
    // from memory (dequantization path) this holds the stored value instead.
    float scale;

    DS_D_INLINE Params(float max)
    {
        // A zero range would yield an infinite scale; fall back to 1.
        scale = (max == 0) ? 1.0f : (1 << numBits) / (2 * max);
    }

    // Quantizes a __half to a saturated numBits-wide signed code.
    DS_D_INLINE int8_t quantize(__half val)
    {
        constexpr int32_t q_min = -(1 << (numBits - 1));
        constexpr int32_t q_max = (1 << (numBits - 1)) - 1;

        const float scaled = conversion::to<float>(val) * scale;
        const int32_t rounded = conversion::to<int32_t>(scaled);
        // Saturate to the representable signed range.
        const int32_t clamped = min(max(rounded, q_min), q_max);
        return (int8_t)clamped;
    }

    // Maps an int8 code back to T by multiplying with the stored scale.
    template <typename T>
    DS_D_INLINE T dequantize(int8_t val)
    {
        return conversion::to<T>(conversion::to<float>(val) * scale);
    }

    // Persists the reciprocal scale so the dequantization path can
    // load and multiply by it directly.
    DS_D_INLINE void store(float* params, int group_index)
    {
        const float inv_scale = 1 / scale;
        mem_access::store_global<sizeof(float)>(params + group_index, &inv_scale);
    }

    DS_D_INLINE Params(const float* params, int group_index)
    {
        mem_access::load_global<sizeof(float)>(&scale, params + group_index);
    }
};
// Asymmetric quantization parameters: a scale plus an offset mapping the
// range [min, max] onto 2^numBits signed integer levels.
template <int numBits>
class Params<Type::Asymmetric, numBits> {
public:
    float scale;   // multiplier applied before rounding
    float offset;  // additive shift centering [min, max] on the signed range

    DS_D_INLINE Params(float max, float min)
    {
        // A degenerate range would yield an infinite scale; fall back to 1.
        if (max == min) {
            scale = 1.0;
        } else {
            scale = (1 << numBits) / (max - min);
        }
        offset = -(1 << (numBits - 1)) - (min * scale);
    }

    // Quantizes a __half to a saturated numBits-wide signed code.
    DS_D_INLINE int8_t quantize(__half val)
    {
        constexpr int32_t q_min = -(1 << (numBits - 1));
        constexpr int32_t q_max = (1 << (numBits - 1)) - 1;

        float val_f = conversion::to<float>(val) * scale + offset;
        int32_t data_i32 = conversion::to<int32_t>(val_f);
        data_i32 = min(max(data_i32, q_min), q_max);
        return (int8_t)data_i32;
    }

    // Maps an int8 code back to T.
    // Fix: previously converted to __half regardless of T (the template
    // parameter was ignored), forcing a lossy round-trip through half
    // precision; the Symmetric specialization correctly uses to<T>.
    template <typename T>
    DS_D_INLINE T dequantize(int8_t val)
    {
        const float val_deq_f = conversion::to<float>(val) * scale + offset;
        return conversion::to<T>(val_deq_f);
    }

    DS_D_INLINE void store(float* params, int group_index)
    {
        // Codegen should turn this into stg.64
        const float store_scale = 1 / scale;
        mem_access::store_global<sizeof(float)>(params + 2 * group_index, &store_scale);
        mem_access::store_global<sizeof(float)>(params + 2 * group_index + 1, &offset);
    }

    DS_D_INLINE Params(const float* params, int group_index)
    {
        // Codegen should turn this into ldg.64
        mem_access::load_global<sizeof(float)>(&scale, params + 2 * group_index);
        mem_access::load_global<sizeof(float)>(&offset, params + 2 * group_index + 1);
    }
};
/*
Group stats tracks the necessary statistics about the quantized group
to abstract the particulars for the main loop.
*/
// Primary template: declared only; concrete behavior lives in the
// Type::Symmetric / Type::Asymmetric specializations below.
template <Type qType>
class GroupStats {
public:
    // Folds one __half2 value into the running statistics.
    DS_D_INLINE void update(__half2 val);

    // Reduces the per-thread statistics across the quantization group.
    DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
};
template <>
class GroupStats<Type::Symmetric> {
public:
    // Symmetric quantization only tracks the maximum absolute value
    __half2 cur_max;
    // NOTE(review): this member appears unused — get_params declares its own
    // local `max`, and nothing here reads or writes it. Confirm before removing.
    float max;

    /*
    Technically, this would give bad results if there
    are 0 values to process since the reduction would
    give -inf instead of 0. We do not consider this
    to be a reasonable edge case.
    */
    DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }

    /*
    Updated the running absmax used to calculate params.

    Function Arguments :
        val : The __half2 value to update the running min and max with.
    */
    DS_D_INLINE void update(__half2 val)
    {
        // __habs2 folds both lanes to absolute values before the max.
        cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
    }

    /*
    Function to return calculated quantization params.

    Template Arguments :
        numBits - Number of bits in quantized element. int : 8 or 4

    Function Arguments :
        tb - Threadblock object. cg::thread_block
        warp - Warp object. cg::thread_block_tile<hw_warp_size>
    */
    template <int numBits, int threads_per_group>
    DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
        cg::thread_block& tb,
        cg::thread_block_tile<hw_warp_size>& warp)
    {
        // Collapse the packed __half2 lanes, then reduce across the group
        // (partitioned_block updates `max` in place).
        const float2 partial_max = conversion::to<float2>(cur_max);
        float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);

        reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
        Params<Type::Symmetric, numBits> params(max);

        return params;
    }
};
template <>
class GroupStats<Type::Asymmetric> {
public:
    __half2 cur_max;
    __half2 cur_min;
    /*
    Construct with cur_max at -inf and cur_min at +inf so the first
    update() establishes the true observed range.
    */
    DS_D_INLINE GroupStats()
    {
        cur_max = reduce::init<rop::Max, __half2>();
        cur_min = reduce::init<rop::Min, __half2>();
    }
    /*
    Fold a packed value into the running min/max range.
    Function Arguments :
        val : The __half2 value to update the running min and max with.
    */
    DS_D_INLINE void update(__half2 val)
    {
        cur_min = reduce::element<rop::Min>(cur_min, val);
        cur_max = reduce::element<rop::Max>(cur_max, val);
    }
    /*
    Compute the asymmetric quantization parameters for this group.
    Template Arguments :
        numBits - Number of bits in quantized element. int : 8 or 4
        threads_per_group - Number of threads cooperating on one group.
    Function Arguments :
        tb - Threadblock object. cg::thread_block
        warp - Warp object. cg::thread_block_tile<hw_warp_size>
    */
    template <int numBits, int threads_per_group>
    DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
        cg::thread_block& tb,
        cg::thread_block_tile<hw_warp_size>& warp)
    {
        // Collapse the packed half lanes into scalar per-thread partials.
        const float2 max_pair = conversion::to<float2>(cur_max);
        const float2 min_pair = conversion::to<float2>(cur_min);
        float max = reduce::element<rop::Max>(max_pair.x, max_pair.y);
        float min = reduce::element<rop::Min>(min_pair.x, min_pair.y);
        // Fused two-value reduction across the partition owning this group.
        reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);
        return Params<Type::Asymmetric, numBits>(max, min);
    }
};
/*
Device function that quantizes 16 bytes of __half type input data.
Template Arguments :
    numBits - Number of bits in quantized element. int : 8 or 4
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
    local_output - Pointer to local memory to store quantized data. int8_t*
    data - Pointer to input data. __half*
    q_params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params);
/*
Device function that quantizes 16 bytes of __half2 type input data.
Template Arguments :
    numBits - Number of bits in quantized element. int : 8 or 4
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
    local_output - Pointer to local memory to store quantized data. int8_t*
    data - Pointer to input data. __half2*
    q_params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params);
/*
Helper function to do serial reduction on register-file arrays.
Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numChunks - Number of 16-byte chunks held in the register buffer. int : 8 or 4
Function Arguments :
    local_buffer - Pointer to memory with input half2 data to be reduced.
*/
template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer);
/*
The main loop of the kernel that quantizes an array in local memory of __half2
type input data, when quantization parameters are pre-computed.
Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numBits - Number of bits in quantized element. int : 8 or 4
    numChunks - Number of chunks (16 bytes of input data). int : 8 or 4
Function Arguments :
    tb - Threadblock object. cg::thread_block
    warp - Warp object. cg::thread_block_tile<hw_warp_size>
    local_buffer - Pointer to memory with input half2 data to be quantized.
    global_params - Pointer to output scale/offset parameter pairs.
    output_data - Pointer to output data.
    elems_per_group - Number of elements to quantize in a group.
    groups - Total number of quantization groups.
    q_params - Quantization parameters.
NOTE: this declaration previously advertised separate `scales`/`offsets`
pointers and a different template-parameter order, matching no definition in
this header; it has been aligned with the actual definition below.
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half2* local_buffer,
                             float* __restrict__ global_params,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params);
/*
The main loop of the kernel that quantizes an array in local memory of __half2
type input data. This overload computes the quantization parameters for each
group itself.
Template Arguments :
    qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
    numBits - Number of bits in quantized element. int : 8 or 4
    numChunks - Number of chunks (16 bytes of input data). int : 8 or 4
Function Arguments :
    local_buffer - Pointer to memory with input half2 data to be quantized.
    global_params - Pointer to output scale/offset parameter pairs.
    output_data - Pointer to output data.
    elems_per_group - Number of elements to quantize in a group.
    groups - Total number of quantization groups.
NOTE: this declaration previously advertised separate `scales`/`offsets`
pointers, matching no definition in this header; it has been aligned with the
actual definition below. Default template arguments are supplied at the
definition and therefore intentionally omitted here.
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half2* local_buffer,
                            float* __restrict__ global_params,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups);
/*
Quantize one 16-byte chunk of __half input into `local_output`.
For 8-bit quantization each half maps to one output byte; for 4-bit,
adjacent pairs are packed into a single byte via PackedInt4.
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params)
{
    constexpr int32_t num_elems = 16 / sizeof(__half);
    constexpr int32_t elems_per_byte = 8 / numBits;
#pragma unroll
    for (int in_idx = 0, out_idx = 0; in_idx < num_elems; in_idx += elems_per_byte, out_idx++) {
        if (elems_per_byte == 1) {
            // TODO(cmikeh2): refactor to use conversion utils
            local_output[in_idx] = q_params.quantize(data[in_idx]);
        } else if (elems_per_byte == 2) {
            // Second element of the pair occupies the first PackedInt4 field.
            int8_t q_second = q_params.quantize(data[in_idx + 1]);
            int8_t q_first = q_params.quantize(data[in_idx]);
            auto packed = PackedInt4{q_second, q_first};
            local_output[out_idx] = *reinterpret_cast<int8_t*>(&packed);
        }
    }
}
/*
Quantize a 16-byte chunk presented as __half2 by forwarding to the scalar
__half implementation; the packed pairs are bitwise-identical to a flat
array of __half values.
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params)
{
    _chunk<numBits>(local_output, reinterpret_cast<const __half*>(data), q_params);
}
/*
Serially fold every __half2 element of the register-file buffer into a
fresh GroupStats accumulator and return it.
*/
template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer)
{
    GroupStats<qType> stats;
    constexpr int total_h2 = numChunks * h2_per_load;
#pragma unroll
    for (int idx = 0; idx < total_h2; idx++) { stats.update(local_buffer[idx]); }
    return stats;
}
// Quantization main loop: each thread quantizes its numChunks register-file
// chunks and writes packed int8 output; one thread per group publishes the
// group's scale/offset params. thread_index().y selects the group within the
// block, thread_index().x the position inside the group.
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half2* local_buffer,
                             float* __restrict__ global_params,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params)
{
    // Quantized elements packed per output byte (1 for 8-bit, 2 for 4-bit).
    constexpr int num_ele_int8 = 8 / numBits;
    // Output bytes produced per 16-byte load of halves.
    constexpr int num_int8_out = quantize::h_per_load / num_ele_int8;
    // Indexing offsets
    const int block_num =
        (tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y;
    const int block_offset = block_num * elems_per_group;
    const int elem_offset = tb.thread_index().x * quantize::h_per_load;
    const int base_offset = (block_offset + elem_offset) / num_ele_int8;
    // Stride (in output bytes) between consecutive chunks of one thread.
    // NOTE(review): derived from tb.size() (the whole block) even though only
    // threads_per_group threads share a group -- confirm this matches the
    // intended data layout.
    const int stride = tb.size() * quantize::h_per_load / num_ele_int8;
    int8_t local_output[num_int8_out];
    // One thread per (valid) group publishes its params to global memory.
    if (tb.thread_index().x == 0 && block_num < groups) {
        q_params.store(
            global_params,
            (tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y);
    }
#pragma unroll
    for (int i = 0; i < numChunks; i++) {
        // Guard the ragged tail: later chunks may fall outside the group.
        if (elem_offset + i * stride * num_ele_int8 < elems_per_group && block_num < groups) {
            quantize::_chunk<numBits, qType>(
                local_output, local_buffer + i * quantize::h2_per_load, q_params);
            mem_access::store_global<num_int8_out>(output_data + (base_offset + i * stride),
                                                   local_output);
        }
    }
}
/*
Overload accepting a flat __half buffer: reinterprets it as packed __half2
and forwards to the __half2 implementation above (the two layouts are
bitwise-identical).
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
                             cg::thread_block_tile<hw_warp_size>& warp,
                             __half* local_buffer,
                             float* __restrict__ global_params,
                             int8_t* __restrict__ output_data,
                             const int& elems_per_group,
                             const int& groups,
                             Params<qType, numBits> q_params)
{
    __half2* local_buffer_h2 = reinterpret_cast<__half2*>(local_buffer);
    // Bug fix: forward the reinterpreted __half2 pointer. Passing the original
    // __half* re-selected this same overload and recursed infinitely.
    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        tb, warp, local_buffer_h2, global_params, output_data, elems_per_group, groups, q_params);
}
// Kernel-side entry point: computes group statistics from the register buffer,
// derives the quantization params cooperatively, then runs the main loop.
template <Type qType,
          int numBits,
          int numChunks,
          // NOTE(review): this default binds to a previously declared
          // max_threads (template parameters cannot reference later ones),
          // not the template parameter on the next line -- confirm intended.
          int threads_per_group = max_threads,
          int max_threads = 256>
__device__ void local_array(__half2* local_buffer,
                            float* __restrict__ global_params,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups)
{
    cg::thread_block tb = cg::this_thread_block();
    cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
    // Per-thread serial reduction over the register buffer...
    auto group_stats = _local_serial_reduce<qType, numChunks>(local_buffer);
    // ...then a cooperative reduction to produce this group's params.
    auto params = group_stats.template get_params<numBits, threads_per_group>(tb, warp);
    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        tb, warp, local_buffer, global_params, output_data, elems_per_group, groups, params);
}
/*
Flat __half entry point: bit-reinterprets the buffer as __half2 and
delegates to the packed entry point above.
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half* local_buffer,
                            float* __restrict__ global_params,
                            int8_t* __restrict__ output_data,
                            const int& elems_per_group,
                            const int& groups)
{
    quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
        reinterpret_cast<__half2*>(local_buffer),
        global_params,
        output_data,
        elems_per_group,
        groups);
}
} // namespace quantize
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cooperative_groups.h>
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
namespace reduce {
// Supported reduction operators for the APIs in this namespace.
enum class ROpType {
    // Addition
    Add,
    // Maximum reduction
    Max,
    // Minimum reduction
    Min,
};
// Upper bound on threads per block supported by the block reductions below.
constexpr int max_threads = 1024;
// Warps per block at the maximum supported block size.
constexpr int max_warps = max_threads / hw_warp_size;
/*
High level API. The API takes in a set of operations and variables
and performs that reduction operation on that variable. The reductions
of each of the arguments are completely independent of each other (
i.e., the val1-op1 combination has no impact on val2-op2).
Example usage:
``` cpp
float max_val;
float min_val;
reduce::block<rop::Max, rop::Min>(tb, warp, max_val, min_val);
```
TODO(cmikeh2): In theory, we might be able to do this sequentially with
device functions and rely on the assembler correctly behaving. My initial
instinct is this won't work, but if it does it would reduce implementation
cost significantly.
TODO(cmikeh2): We need to support sub-block reductions. The warp intrinsic
currently supports this (more incidentally than anything else). It is not
uncommon in something like softmax or a fused attention kernel to map multiple
reductions to a thread block, but each reduction itself is only scoped
to part of the threads (i.e block size = 512, 128 threads per reduction).
*/
// Block-wide reduction of a single value (result left in `val`).
template <ROpType Op, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val);
// Block-wide reduction of two independent values.
template <ROpType Op1, ROpType Op2, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2);
// Block-wide reduction of three independent values.
template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3);
// Block-wide reduction of four independent values.
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound = max_warps>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4);
/*
The partitioned block is a special case of the above where in the warps of a threadblock are
partitioned into separate independent reductions. For example, I might have an 8 warp thread block
in which each pair of warps is processing an independent piece of data. I would then reduce that
data with the something like the following:
``` cpp
float max_val;
reduce::partitioned_block<rop::Max, 2>(tb, warp, max_val);
```
After which, each pair of warps would have coherent data with each other. Note, this API will not
provide correct results if the number of warps per partition is not a power of 2.
*/
// Partition-wide reduction of a single value (num_threads per partition).
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val);
// Partition-wide reduction of two independent values.
template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2);
// Partition-wide reduction of three independent values.
template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3);
// Partition-wide reduction of four independent values.
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4);
/*
Single element reduction primitives. Used inside serial collection
loops.
Example usage:
using rop = reduce::ROpType;
float min = init<rop::Min>();
for (int i = 0; i < 4; i++) {
    min = reduce::element<rop::Min>(min, data[i]);
}
*/
// Combine two scalars under the reduction op (specialized per type below).
template <ROpType Op, typename T>
DS_D_INLINE T element(const T lhs, const T rhs);
// Identity value for the reduction op (0 for Add, +inf for Min, -inf for Max;
// see the specializations below).
template <ROpType OType, typename T = float>
DS_D_INLINE T init();
/********************** Internal reduction APIs **********************/
/*
Single element "reductions". TODO(cmikeh2): this sort of "op" concept
should be refactored into its own implementation at some point. This interface
may be easily expanded for new types/operations, but the typical reductions
we need are covered with min/max/add on float.
NOTE: there is no mean reduction because that relies on knowledge of how
many values were already reduced into each scalar. Implementing this on top
of reduce should be straightforward (can just wrap the sum reduction) and
would be a good extension of the header.
*/
/* Float element reduce implementations */
template <>
DS_D_INLINE float element<ROpType::Add>(const float lhs, const float rhs)
{
    return lhs + rhs;
}
// fmaxf/fminf follow C99 semantics: when exactly one operand is NaN, the
// numeric operand is returned.
template <>
DS_D_INLINE float element<ROpType::Max>(const float lhs, const float rhs)
{
    return fmaxf(lhs, rhs);
}
template <>
DS_D_INLINE float element<ROpType::Min>(const float lhs, const float rhs)
{
    return fminf(lhs, rhs);
}
/* __half element reduce implementation */
template <>
DS_D_INLINE __half element<ROpType::Add>(const __half lhs, const __half rhs)
{
    return lhs + rhs;
}
template <>
DS_D_INLINE __half element<ROpType::Max>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmax(lhs, rhs);
#else
    // NOTE(review): the comparison fallback yields rhs when lhs is NaN, which
    // may differ from __hmax's NaN handling -- confirm acceptable.
    return (lhs > rhs) ? lhs : rhs;
#endif
}
template <>
DS_D_INLINE __half element<ROpType::Min>(const __half lhs, const __half rhs)
{
#if __CUDA_ARCH__ >= 800
    // Intrinsic limited to Ampere + newer
    return __hmin(lhs, rhs);
#else
    return (lhs < rhs) ? lhs : rhs;
#endif
}
/* __half2 element reduce implementation (applied lane-wise to .x and .y) */
template <>
DS_D_INLINE __half2 element<ROpType::Add>(const __half2 lhs, const __half2 rhs)
{
    return lhs + rhs;
}
template <>
DS_D_INLINE __half2 element<ROpType::Max>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    // Packed intrinsic limited to Ampere + newer
    return __hmax2(lhs, rhs);
#else
    // Lane-wise comparison fallback for older architectures.
    __half2 ret_val;
    ret_val.x = (lhs.x > rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y > rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
template <>
DS_D_INLINE __half2 element<ROpType::Min>(const __half2 lhs, const __half2 rhs)
{
#if __CUDA_ARCH__ >= 800
    return __hmin2(lhs, rhs);
#else
    __half2 ret_val;
    ret_val.x = (lhs.x < rhs.x) ? lhs.x : rhs.x;
    ret_val.y = (lhs.y < rhs.y) ? lhs.y : rhs.y;
    return ret_val;
#endif
}
/*
Reduction initialization primitives: the identity element for each op.
*/
template <>
DS_D_INLINE float init<ROpType::Add>()
{
    return 0.0f;
}
template <>
DS_D_INLINE float init<ROpType::Min>()
{
    // Positive infinity
    return INFINITY;
}
template <>
DS_D_INLINE float init<ROpType::Max>()
{
    // Negative infinity
    return -INFINITY;
}
// The __half/__half2 identities are built from raw fp16 bit patterns:
// 0x0000 = +0.0, 0x7C00 = +inf, 0xFC00 = -inf.
template <>
DS_D_INLINE __half init<ROpType::Add>()
{
    constexpr __half_raw zero = {0x0000};
    return __half(zero);
}
template <>
DS_D_INLINE __half init<ROpType::Min>()
{
    constexpr __half_raw inf = {0x7C00};
    return __half(inf);
}
template <>
DS_D_INLINE __half init<ROpType::Max>()
{
    constexpr __half_raw neg_inf = {0xFC00};
    return __half(neg_inf);
}
template <>
DS_D_INLINE __half2 init<ROpType::Add>()
{
    constexpr __half2_raw zero = {0x0000, 0x0000};
    return __half2(zero);
}
template <>
DS_D_INLINE __half2 init<ROpType::Min>()
{
    constexpr __half2_raw inf = {0x7C00, 0x7C00};
    return __half2(inf);
}
template <>
DS_D_INLINE __half2 init<ROpType::Max>()
{
    constexpr __half2_raw neg_inf = {0xFC00, 0xFC00};
    return __half2(neg_inf);
}
// Array-initialization helpers: write one reduction identity per requested op
// into the leading slots of `data` (used by _block below to pad lanes that
// have no warp partial to load).
template <ROpType Op, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op, T>();
}
template <ROpType Op1, ROpType Op2, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
}
template <ROpType Op1, ROpType Op2, ROpType Op3, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
}
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, typename T>
DS_D_INLINE void init(T* data)
{
    data[0] = init<Op1, T>();
    data[1] = init<Op2, T>();
    data[2] = init<Op3, T>();
    data[3] = init<Op4, T>();
}
/*
Warp-scope butterfly (xor-shuffle) reductions.
`reduce_width` is an unsafe template parameter: when it is smaller than
hw_warp_size the warp is split into hw_warp_size / reduce_width groups of
partial sums, each reducing independently.
If someone can figure out how to use variadic templates in a reasonable way
here (fold is C++17 only and I don't think helps and recursion feels like
huge overkill that harms readability) that would be wonderful.
*/
template <ROpType Op, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int offset = 1; offset < reduce_width; offset *= 2) {
        data[0] = element<Op>(data[0], warp.shfl_xor(data[0], offset));
    }
}
template <ROpType Op1, ROpType Op2, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int offset = 1; offset < reduce_width; offset *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], offset));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], offset));
    }
}
template <ROpType Op1, ROpType Op2, ROpType Op3, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int offset = 1; offset < reduce_width; offset *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], offset));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], offset));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], offset));
    }
}
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int reduce_width = hw_warp_size>
DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, float* data)
{
#pragma unroll
    for (int offset = 1; offset < reduce_width; offset *= 2) {
        data[0] = element<Op1>(data[0], warp.shfl_xor(data[0], offset));
        data[1] = element<Op2>(data[1], warp.shfl_xor(data[1], offset));
        data[2] = element<Op3>(data[2], warp.shfl_xor(data[2], offset));
        data[3] = element<Op4>(data[3], warp.shfl_xor(data[3], offset));
    }
}
/*
Implementation for primary block reduction that serves both `block` and
`partitioned_block`.
`local_warp_rank` refers to the warp's location within the partition, so
for an unpartitioned threadblock this will be equivalent to
`warp_arg.meta_group_rank()`.
Similarly, the warp offset is the `local_warp_rank` of the warp with the
lowest rank in the partition. In the case of an 8 warp block with a
4 warp reduction, this would map to [0, 0, 0, 0, 4, 4, 4, 4].
Partition size is the number of warps per partition (equal to the thread
block in the default case). This enables us to only perform the warp reduction
when able to.
NOTE(review): `warp_offset` is currently unused in the body; the partitioned
behavior appears to fall out of the reduce_width (total_warps) used in the
second _warp call -- confirm the parameter is still needed.
*/
template <int total_warps, ROpType... Ops>
DS_D_INLINE void _block(cg::thread_block& tb,
                        cg::thread_block_tile<hw_warp_size>& warp_arg,
                        float* data,
                        int warp_offset)
{
    constexpr int elems = sizeof...(Ops);
    // Separated for now in case this no longer is true
    constexpr int bytes = sizeof(float);
    // Unused when `partition_size == 1` or total_warps == 1
    __shared__ float reduce_buffer[max_warps * elems];
    // Always perform warp-scope reduction
    _warp<Ops...>(warp_arg, data);
    // If max_warps == 1 let's skip the runtime check
    if (warp_arg.meta_group_size() > 1 && total_warps != 1) {
        // Lane 0 of each warp publishes its warp's partials to shared memory.
        if (warp_arg.thread_rank() == 0) {
#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(
                    reduce_buffer + elems * warp_arg.meta_group_rank() + i, data + i);
            }
        }
        // Synchronization inside block-uniform conditional is safe
        tb.sync();
        // Warp 0 reduces the per-warp partials: lanes beyond the warp count
        // are padded with the ops' identity values before the shuffle pass.
        if (warp_arg.meta_group_rank() == 0) {
            if (warp_arg.thread_rank() < warp_arg.meta_group_size()) {
#pragma unroll
                for (int i = 0; i < elems; i++) {
                    mem_access::load_shared<bytes>(
                        data + i, reduce_buffer + elems * warp_arg.thread_rank() + i);
                }
            } else {
                init<Ops...>(data);
            }
            _warp<Ops..., total_warps>(warp_arg, data);
#pragma unroll
            for (int i = 0; i < elems; i++) {
                mem_access::store_shared<bytes>(reduce_buffer + elems * warp_arg.thread_rank() + i,
                                                data + i);
            }
        }
        // Synchronization inside block-uniform conditional is safe
        tb.sync();
        // Every warp reads back the reduced result for its own slot.
#pragma unroll
        for (int i = 0; i < elems; i++) {
            mem_access::load_shared<bytes>(data + i,
                                           reduce_buffer + warp_arg.meta_group_rank() * elems + i);
        }
    }
}
/*
Public block-reduction entry points. Each overload packs its arguments into a
small scratch array, runs the shared `_block` implementation, then unpacks the
results back into the caller's variables. (Using `_block` directly is possible
but the raw-pointer interface is less safe and exposes the partitioning
details.)
*/
template <ROpType Op, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp, float& val)
{
    _block<warp_bound, Op>(tb, warp, &val, 0);
}
template <ROpType Op1, ROpType Op2, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2)
{
    float scratch[2] = {val1, val2};
    _block<warp_bound, Op1, Op2>(tb, warp, scratch, 0);
    val1 = scratch[0];
    val2 = scratch[1];
}
template <ROpType Op1, ROpType Op2, ROpType Op3, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3)
{
    float scratch[3] = {val1, val2, val3};
    _block<warp_bound, Op1, Op2, Op3>(tb, warp, scratch, 0);
    val1 = scratch[0];
    val2 = scratch[1];
    val3 = scratch[2];
}
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int warp_bound>
DS_D_INLINE void block(cg::thread_block& tb,
                       cg::thread_block_tile<hw_warp_size>& warp,
                       float& val1,
                       float& val2,
                       float& val3,
                       float& val4)
{
    float scratch[4] = {val1, val2, val3, val4};
    _block<warp_bound, Op1, Op2, Op3, Op4>(tb, warp, scratch, 0);
    val1 = scratch[0];
    val2 = scratch[1];
    val3 = scratch[2];
    val4 = scratch[3];
}
/*
Note: the partitioned implementations do not support non-power-of-2 partition
sizes, in order to shorten the block-scale reduction length.
*/
template <ROpType Op, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val)
{
    if (num_threads <= hw_warp_size) {
        // Partition fits inside one warp: a width-limited shuffle suffices.
        _warp<Op, num_threads>(warp, &val);
    } else {
        constexpr int warps_per_partition = num_threads / hw_warp_size;
        // Lowest warp rank of the partition this warp belongs to.
        const int partition_base = warp.meta_group_rank() & ~(warps_per_partition - 1);
        _block<warps_per_partition, Op>(tb, warp, &val, partition_base);
    }
}
template <ROpType Op1, ROpType Op2, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2)
{
    float scratch[2] = {val1, val2};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, num_threads>(warp, scratch);
    } else {
        constexpr int warps_per_partition = num_threads / hw_warp_size;
        const int partition_base = warp.meta_group_rank() & ~(warps_per_partition - 1);
        _block<warps_per_partition, Op1, Op2>(tb, warp, scratch, partition_base);
    }
    val1 = scratch[0];
    val2 = scratch[1];
}
template <ROpType Op1, ROpType Op2, ROpType Op3, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3)
{
    float scratch[3] = {val1, val2, val3};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, num_threads>(warp, scratch);
    } else {
        constexpr int warps_per_partition = num_threads / hw_warp_size;
        const int partition_base = warp.meta_group_rank() & ~(warps_per_partition - 1);
        _block<warps_per_partition, Op1, Op2, Op3>(tb, warp, scratch, partition_base);
    }
    val1 = scratch[0];
    val2 = scratch[1];
    val3 = scratch[2];
}
template <ROpType Op1, ROpType Op2, ROpType Op3, ROpType Op4, int num_threads>
DS_D_INLINE void partitioned_block(cg::thread_block& tb,
                                   cg::thread_block_tile<hw_warp_size>& warp,
                                   float& val1,
                                   float& val2,
                                   float& val3,
                                   float& val4)
{
    float scratch[4] = {val1, val2, val3, val4};
    if (num_threads <= hw_warp_size) {
        _warp<Op1, Op2, Op3, Op4, num_threads>(warp, scratch);
    } else {
        constexpr int warps_per_partition = num_threads / hw_warp_size;
        const int partition_base = warp.meta_group_rank() & ~(warps_per_partition - 1);
        _block<warps_per_partition, Op1, Op2, Op3, Op4>(tb, warp, scratch, partition_base);
    }
    val1 = scratch[0];
    val2 = scratch[1];
    val3 = scratch[2];
    val4 = scratch[3];
}
} // namespace reduce
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#if (__x86_64__ || __i386__)
......@@ -22,7 +26,7 @@
#define SIMD_WIDTH 16
#define SIMD_LOAD2(x, h) \
((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x))
((h) ? _mm512_cvtph_ps(_mm256_castps_si256(_mm256_loadu_ps(x))) : _mm512_loadu_ps(x))
#define SIMD_STORE2(x, d, h) \
((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \
: _mm512_storeu_ps(x, d))
......@@ -60,18 +64,16 @@ union AVX_Data {
template <int span>
inline void simd_store(float* dst, AVX_Data* src, bool half_precision)
{
size_t width = (half_precision ? SIMD_WIDTH / 2 : SIMD_WIDTH);
#pragma unroll
for (size_t i = 0; i < span; ++i) {
SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision);
}
for (size_t i = 0; i < span; ++i) { SIMD_STORE2(dst + width * i, src[i].data, half_precision); }
}
template <int span>
inline void simd_load(AVX_Data* dst, float* src, bool half_precision)
{
size_t width = (half_precision ? 1 : SIMD_WIDTH);
#pragma unroll
for (size_t i = 0; i < span; ++i) {
dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision);
}
for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_LOAD2(src + width * i, half_precision); }
}
template <int span>
inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment