Commit 67ea635f authored by aiss

push dsv0.8.2 version

parent 1b2721ad
Pipeline #201 failed with stages in 0 seconds
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#pragma once
#include <cuda.h>
......
/*
Copyright The Microsoft DeepSpeed Team
*/
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
......
......@@ -7,12 +7,12 @@
#include "ATen/TensorUtils.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
-//#include "ATen/Type.h"
+// #include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
-//#include <helper_functions.h>
+// #include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
......@@ -105,7 +105,7 @@ __device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
cg::sync(cta);
-#if (__CUDA_ARCH__ >= 300)
+#if (__CUDA_ARCH__ >= 300) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 502)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
......
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "dequantization_utils.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
template <typename T, int numBits, dequantize::Type qType, int unroll, int threads>
__global__ void dequantize_kernel(T* __restrict__ dequant_data,
const int8_t* __restrict__ q_data,
const float* __restrict__ q_params,
int elems_per_group,
int total_elems)
{
dequantize::to_global<T, numBits, qType, unroll, threads>(
dequant_data, q_data, q_params, elems_per_group, total_elems);
}
#define LAUNCH_DEQUANT_KERNEL(num_bits, q_type) \
dequantize_kernel<T, num_bits, q_type, unroll, threads><<<grid, block, 0, stream>>>( \
dequant_data, q_data, q_params, elems_per_group, total_elems);
template <typename T>
void launch_dequantize_kernel(T* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream)
{
constexpr int unroll = 8;
constexpr int threads = 512;
constexpr int elems_per_block = unroll * threads * dequantize::granularity / (sizeof(T));
const dim3 block(threads);
const dim3 grid((total_elems + elems_per_block - 1) / elems_per_block);
// TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
// problem sizes with this large unroll value.
if (num_bits == 8 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Symmetric);
} else if (num_bits == 8 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Asymmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Symmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Asymmetric);
}
}
template void launch_dequantize_kernel(__half* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream);
template void launch_dequantize_kernel(float* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream);
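/*
A minimal host-side usage sketch for the launcher above (illustrative only, not
part of this commit). The helper name `example_dequantize`, the buffer sizes,
and the one-float-per-group symmetric parameter layout are assumptions.
*/
void example_dequantize(cudaStream_t stream)
{
    const int groups = 32;
    const int elems_per_group = 4096;
    const int total_elems = groups * elems_per_group;
    int8_t* q_data;
    float* q_params;  // 1 float per group for Symmetric; 2 when an offset is required
    __half* dequant_data;
    cudaMalloc(&q_data, total_elems * sizeof(int8_t));
    cudaMalloc(&q_params, groups * sizeof(float));
    cudaMalloc(&dequant_data, total_elems * sizeof(__half));
    // 8-bit symmetric dequantization; grid/block sizing happens inside the launcher.
    launch_dequantize_kernel(dequant_data,
                             (const int8_t*)q_data,
                             (const float*)q_params,
                             quantize::Type::Symmetric,
                             /*num_bits=*/8,
                             elems_per_group,
                             total_elems,
                             stream);
}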
/*
Copyright The Microsoft DeepSpeed Team
*/
#include <math.h>
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
-__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
+__global__ void fake_quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
-    cg::thread_block b = cg::this_thread_block();
-    cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
+    cg::thread_block b = cg::this_thread_block();  // tb
+    cg::thread_block_tile<32> g =
+        cg::tiled_partition<32>(b);  // warp; 32 is not optimal for AMD, which should use 64.
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
-    float2* vals_cast = reinterpret_cast<float2*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
-    float2 data[MAX_REG];
+    __half data[vals_per_access];
int group_id = blockIdx.x;
-    {
-        int group_index = id;
+    int thread_index = id * vals_per_access;
-        int reg_count = 0;
    int offset = group_id * group_size;
    float max = -10000.0;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
-        while (group_index < group_size && reg_count < MAX_REG) {
-            data[reg_count] = vals_cast[offset + group_index];
-            __half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
-            if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
-            if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
-            if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
-            if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
-            group_index += blockDim.x;
-            reg_count++;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs((float)data[i]) > max) max = abs((float)data[i]);
+        }
    }
#pragma unroll
......@@ -61,39 +63,30 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
max = g.shfl(max, 0);
-    float q_scale = (1 << num_bits) / (2 * max + 1e-5);
+    float q_scale = (float)(1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
-    for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
-        if (group_index < group_size) {
-            __half2* data_h = reinterpret_cast<__half2*>(&data[i]);
-            float2 q_data[2];
-            q_data[0] = __half22float2(data_h[0]);
-            q_data[1] = __half22float2(data_h[1]);
-            float2 q_data_int[2];
-            q_data_int[0].x = roundf(q_data[0].x * q_scale);
-            q_data_int[0].y = roundf(q_data[0].y * q_scale);
-            q_data_int[1].x = roundf(q_data[1].x * q_scale);
-            q_data_int[1].y = roundf(q_data[1].y * q_scale);
-            q_data_int[0].x *= q_scale_inv;
-            q_data_int[0].y *= q_scale_inv;
-            q_data_int[1].x *= q_scale_inv;
-            q_data_int[1].y *= q_scale_inv;
-            data_h[0] = __float22half2_rn(q_data_int[0]);
-            data_h[1] = __float22half2_rn(q_data_int[1]);
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
-            vals_cast[offset + group_index] = data[i];
-        }
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __half2float(data[j]);
+            q_data = __float2int_rn(q_data * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = __float2half_rn(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
}
#endif
}
-__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
+__global__ void fake_quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -103,31 +96,31 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
-    float4* vals_cast = reinterpret_cast<float4*>(vals);
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
-    float4 data[MAX_REG];
+    float data[vals_per_access];
int bid = blockIdx.x;
-    int group_index = bid * group_size + id;
+    int thread_index = id * vals_per_access;
-    int reg_count = 0;
    float max = -10000.0;
+    int offset = bid * group_size;
-    while (id < group_size && reg_count < MAX_REG) {
-        float4 data_reg = vals_cast[group_index];
-        data[reg_count] = data_reg;
-        if (abs(data_reg.x) > max) max = abs(data_reg.x);
-        if (abs(data_reg.y) > max) max = abs(data_reg.y);
-        if (abs(data_reg.z) > max) max = abs(data_reg.z);
-        if (abs(data_reg.w) > max) max = abs(data_reg.w);
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
-        group_index += blockDim.x;
-        id += blockDim.x;
-        reg_count++;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) {
+            if (abs(data[i]) > max) max = abs(data[i]);
+        }
-        id = threadIdx.x;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
......@@ -153,30 +146,27 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
-    for (int i = 0; i < reg_count; i++) {
-        group_index = i * blockDim.x + id;
-        if (group_index < group_size) {
-            float4 q_data;
-            q_data = data[i];
-            float4 q_data_int;
-            q_data_int.x = roundf(q_data.x * q_scale);
-            q_data_int.y = roundf(q_data.y * q_scale);
-            q_data_int.w = roundf(q_data.w * q_scale);
-            q_data_int.z = roundf(q_data.z * q_scale);
-            q_data.x = q_data_int.x * q_scale_inv;
-            q_data.y = q_data_int.y * q_scale_inv;
-            q_data.w = q_data_int.w * q_scale_inv;
-            q_data.z = q_data_int.z * q_scale_inv;
+    int q_range_max = (1 << (num_bits - 1)) - 1;
+    int q_range_min = -(1 << (num_bits - 1));
-            vals_cast[group_index + bid * group_size] = q_data;
+    for (int thread_index = id * vals_per_access; thread_index < group_size;
+         thread_index += blockDim.x * vals_per_access) {
+        mem_access::load_global<granularity>(data, vals + offset + thread_index);
+#pragma unroll
+        for (int j = 0; j < vals_per_access; j++) {
+            float q_data;
+            q_data = __float2int_rn(data[j] * q_scale);
+            q_data = q_data > (q_range_max) ? (q_range_max)
+                                            : (q_data < (q_range_min) ? (q_range_min) : q_data);
+            data[j] = roundf(q_data * q_scale_inv);
+        }
+        mem_access::store_global<granularity>(vals + offset + thread_index, data);
}
}
template <typename T>
-void launch_quantize_kernel(T* vals,
+void launch_fake_quantize_kernel(T* vals,
                                 int total_count,
                                 int group_num,
                                 int num_bits,
......@@ -185,22 +175,22 @@ void launch_quantize_kernel(T* vals,
    dim3 grid_dim(group_num);
    dim3 block_dim(1024);
-    quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
-        vals, (total_count / group_num) / 4, num_bits);
+    fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
+        vals, total_count / group_num, num_bits);
}
-template void launch_quantize_kernel(float* vals,
+template void launch_fake_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-template void launch_quantize_kernel(__half* vals,
+template void launch_fake_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
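/*
The round-trip these kernels implement is compact enough to state on the host.
A scalar CPU reference (a hypothetical validation helper, not part of this
commit): the group scale maps the observed absolute maximum onto the signed
integer range, and fake quantization is round, clamp, rescale.
*/
float fake_quantize_ref(float x, float group_max, int num_bits)
{
    const float q_scale = (float)(1 << num_bits) / (2.f * group_max + 1e-5f);
    const float q_scale_inv = 1.f / q_scale;
    const int q_range_max = (1 << (num_bits - 1)) - 1;
    const int q_range_min = -(1 << (num_bits - 1));
    int q = (int)roundf(x * q_scale);                                         // quantize
    q = q > q_range_max ? q_range_max : (q < q_range_min ? q_range_min : q);  // clamp
    return (float)q * q_scale_inv;                                            // dequantize in place
}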
-__global__ void sr_quantize_kernel(__half* vals,
+__global__ void sr_fake_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
......@@ -336,7 +326,7 @@ __global__ void sr_quantize_kernel(__half* vals,
#endif
}
-__global__ void sr_quantize_kernel(float* vals,
+__global__ void sr_fake_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
......@@ -456,7 +446,7 @@ __global__ void sr_quantize_kernel(float* vals,
}
template <typename T>
-void launch_sr_quantize_kernel(T* vals,
+void launch_sr_fake_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
......@@ -468,21 +458,21 @@ void launch_sr_quantize_kernel(T* vals,
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
-    sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
+    sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
-template void launch_sr_quantize_kernel(float* vals,
+template void launch_sr_fake_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-template void launch_sr_quantize_kernel(__half* vals,
+template void launch_sr_fake_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
+__global__ void fake_quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
......@@ -595,7 +585,7 @@ __global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
#endif
}
-__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
+__global__ void fake_quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -699,7 +689,7 @@ __global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
}
template <typename T>
-void launch_quantize_kernel_asym(T* vals,
+void launch_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
......@@ -708,22 +698,22 @@ void launch_quantize_kernel_asym(T* vals,
dim3 grid_dim(group_num);
dim3 block_dim(1024);
-    quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
+    fake_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
-template void launch_quantize_kernel_asym(float* vals,
+template void launch_fake_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-template void launch_quantize_kernel_asym(__half* vals,
+template void launch_fake_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-__global__ void sr_quantize_kernel_asym(__half* vals,
+__global__ void sr_fake_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
......@@ -879,7 +869,7 @@ __global__ void sr_quantize_kernel_asym(__half* vals,
#endif
}
-__global__ void sr_quantize_kernel_asym(float* vals,
+__global__ void sr_fake_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
......@@ -1010,7 +1000,7 @@ __global__ void sr_quantize_kernel_asym(float* vals,
}
}
template <typename T>
-void launch_sr_quantize_kernel_asym(T* vals,
+void launch_sr_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
......@@ -1022,15 +1012,15 @@ void launch_sr_quantize_kernel_asym(T* vals,
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
-    sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
+    sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
-template void launch_sr_quantize_kernel_asym(float* vals,
+template void launch_sr_fake_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
-template void launch_sr_quantize_kernel_asym(__half* vals,
+template void launch_sr_fake_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
......
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "custom_cuda_layers.h"
#include "quantization.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
......@@ -10,8 +11,8 @@ at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
int size = 1;
for (auto dim : t_size) size *= dim;
-    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
-        launch_quantize_kernel(
+    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
+        launch_fake_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -25,7 +26,7 @@ at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
-        launch_sr_quantize_kernel(
+        launch_sr_fake_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -38,8 +39,8 @@ at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
int size = 1;
for (auto dim : t_size) size *= dim;
-    if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
-        launch_quantize_kernel_asym(
+    if ((((size / groups) - 1) / 4096 + 1) <= 256) {
+        launch_fake_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -53,12 +54,83 @@ at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
-        launch_sr_quantize_kernel_asym(
+        launch_sr_fake_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
int groups,
int numBits,
quantize::Type quantType)
{
auto dtype = at::kFloat;
auto params_options = at::TensorOptions()
.dtype(dtype)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
auto params = torch::empty({groups, param_elems}, params_options);
auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output_sizes = input_vals.sizes().vec();
output_sizes[output_sizes.size() - 1] /= numBits == 8 ? 1 : 2;
auto output = torch::empty(output_sizes, output_options);
const int elems_per_group = at::numel(input_vals) / groups;
launch_quant((int8_t*)output.data_ptr(),
(float*)params.data_ptr(),
(__half*)input_vals.data_ptr(),
groups,
elems_per_group,
numBits,
quantType,
at::cuda::getCurrentCUDAStream());
return {output, params};
}
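/*
A hypothetical C++ caller for quantize_kernel above (shapes illustrative, not
part of this commit). For 4-bit quantization the int8 output packs two values
per byte, which is why the last dimension is halved above.
*/
void example_quantize()
{
    auto input = torch::randn({32, 4096},
                              at::TensorOptions().dtype(at::kHalf).device(at::kCUDA));
    auto result = quantize_kernel(input, /*groups=*/32, /*numBits=*/4,
                                  quantize::Type::Symmetric);
    auto packed = result[0];  // int8 tensor of shape [32, 2048]
    auto params = result[1];  // float tensor of shape [32, 1]: one scale per group
}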
template <typename T>
at::Tensor dequantize(at::Tensor& quantized_data,
at::Tensor& params,
int groups,
int num_bits,
quantize::Type quant_type)
{
auto dtype = (std::is_same<T, float>::value) ? torch::kFloat32 : torch::kFloat16;
auto output_options = at::TensorOptions()
.dtype(dtype)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output_sizes = quantized_data.sizes().vec();
output_sizes[output_sizes.size() - 1] *= num_bits == 8 ? 1 : 2;
auto output = torch::empty(output_sizes, output_options);
const int total_elems = at::numel(output);
const int elems_per_group = total_elems / groups;
launch_dequantize_kernel((T*)output.data_ptr(),
(const int8_t*)quantized_data.data_ptr(),
(const float*)params.data_ptr(),
quant_type,
num_bits,
elems_per_group,
total_elems,
at::cuda::getCurrentCUDAStream());
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
......@@ -74,4 +146,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("ds_sr_quantize_asym_fp16",
&ds_sr_quantize_asym<__half>,
"DeepSpeed Quantize with fp16 (CUDA)");
pybind11::enum_<quantize::Type>(m, "QuantizationType")
.value("Symmetric", quantize::Type::Symmetric)
.value("Asymmetric", quantize::Type::Asymmetric)
.export_values();
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize<__half>);
m.def("dequantize_fp32", &dequantize<float>);
}
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
/*
Pure quantization kernel with no fusion.
*/
template <int q_bits,
quantize::Type quant_type,
int UNROLL,
int internal_unroll,
int threads_per_group,
int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
float* __restrict__ params,
const __half* __restrict__ input_data,
int groups,
int elems_per_group)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// Indexing offsets
const int block_offset =
(tb.group_index().x * (max_threads / threads_per_group) * elems_per_group) +
(tb.thread_index().y * elems_per_group);
const int elem_offset = tb.thread_index().x * quantize::h_per_load;
const int base_offset = block_offset + elem_offset;
const int stride = tb.size() * quantize::h_per_load;
    const __half* input_base = input_data + base_offset;
__half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
        // Convenience pointer; should resolve to register indices and not be materialized.
__half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
for (int j = 0; j < internal_unroll; j++) {
const int iteration = i * internal_unroll + j;
mem_access::load_global<quantize::granularity>(
iteration_buffer + j * quantize::h2_per_load,
input_base + iteration * stride,
elem_offset + iteration * stride < elems_per_group);
}
}
quantize::
local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
local_buffer, params, output_data, elems_per_group, groups);
}
/********* Launcher methods ***********/
#define LAUNCH_CACHED_QUANT_CALL(q_bits, quant_type) \
cached_quantization<q_bits, \
quant_type, \
unroll_factor, \
internal_unroll_l, \
threads_per_group, \
max_threads> \
<<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);
#define LAUNCH_CACHED_QUANT( \
q_bits, quant_type, unroll_factor_in, internal_unroll_in, threads_per_group_in) \
const int unroll_factor = unroll_factor_in; \
const int internal_unroll_l = internal_unroll_in; \
const int threads_per_group = threads_per_group_in; \
if (q_bits == 4) { \
if (quant_type == quantize::Type::Asymmetric) { \
LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Asymmetric) \
} else { \
LAUNCH_CACHED_QUANT_CALL(4, quantize::Type::Symmetric) \
} \
} else { \
if (quant_type == quantize::Type::Asymmetric) { \
LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Asymmetric) \
} else { \
LAUNCH_CACHED_QUANT_CALL(8, quantize::Type::Symmetric) \
} \
}
void launch_quant(int8_t* output_data,
float* params,
const __half* input_data,
const int groups,
const int elems_per_group,
const int num_bits,
const quantize::Type quant_type,
cudaStream_t stream)
{
constexpr int max_threads = 256;
constexpr int internal_unroll = 2;
    const bool is_subblock_schedule = (elems_per_group <= 128);
const int h_per_step = is_subblock_schedule ? quantize::h_per_load
: quantize::h_per_load * internal_unroll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;
const int groups_per_block =
is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;
dim3 block(threads_per_group, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threads_per_group * h_per_step;
const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threads_per_group == 1) {
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 1);
} else if (threads_per_group == 2) {
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 2);
} else if (threads_per_group == 4) {
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 4);
} else if (threads_per_group == 8) {
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 8);
} else if (threads_per_group == 16) {
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, 1, 16);
}
} else if (external_unroll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_CACHED_QUANT(num_bits, quant_type, 1, internal_unroll, max_threads);
} else if (external_unroll == 2) {
// 4097 - 8192 elems
LAUNCH_CACHED_QUANT(num_bits, quant_type, 2, internal_unroll, max_threads);
} else if (external_unroll == 3) {
// 8193 - 12288 elems
LAUNCH_CACHED_QUANT(num_bits, quant_type, 3, internal_unroll, max_threads);
} else if (external_unroll == 4) {
// 12289 - 16384 elems
LAUNCH_CACHED_QUANT(num_bits, quant_type, 4, internal_unroll, max_threads);
}
}
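/*
A worked trace of the scheduling logic above (an illustrative helper, not part
of this commit; assumes quantize::h_per_load == 8, i.e. 16 bytes of __half per
load).
*/
int trace_external_unroll(int elems_per_group)
{
    const int h_per_load = 8;  // assumed value of quantize::h_per_load
    const int max_threads = 256;
    const int internal_unroll = 2;
    const bool is_subblock = (elems_per_group <= 128);
    const int h_per_step = is_subblock ? h_per_load : h_per_load * internal_unroll;
    const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
    const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;
    const int elems_per_step = threads_per_group * h_per_step;
    // e.g. elems_per_group = 6144: h_per_step = 16, one_step_threads = 512 is
    // capped at 256, elems_per_step = 4096, so this returns 2 (the 4097-8192 branch).
    return (elems_per_group + elems_per_step - 1) / elems_per_step;
}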
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
namespace td_data {
constexpr int granularity = 16;
}
template <typename T>
__global__ void gather_tokens_impl(T* retained_tokens,
const T* activations,
int32_t* gather_indices,
int32_t sampled_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride)
{
constexpr int mem_vals_t = td_data::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
const int gather_idx = gather_indices[tb.group_index().x * sampled_tokens + tb.group_index().y];
const int read_offset = read_batch_stride * tb.group_index().x + read_seq_stride * gather_idx;
const int write_offset =
write_batch_stride * tb.group_index().x + write_seq_stride * tb.group_index().y;
for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += blockDim.x * mem_vals_t) {
T local_data[mem_vals_t];
mem_access::load_global<td_data::granularity>(local_data, activations + read_offset + i);
mem_access::store_global<td_data::granularity>(retained_tokens + write_offset + i,
local_data);
}
}
template <typename T>
void launch_gather_tokens(T* retained_tokens,
T* activations,
int32_t* gather_indices,
int32_t batch_size,
int32_t sampled_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride,
cudaStream_t stream)
{
constexpr int mem_vals_t = td_data::granularity / sizeof(T);
const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
const int threads = (load_steps >= 1024) ? 1024 : load_steps;
dim3 block(threads);
dim3 grid(batch_size, sampled_tokens);
gather_tokens_impl<T><<<grid, block, 0, stream>>>(retained_tokens,
activations,
gather_indices,
sampled_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride);
}
template void launch_gather_tokens<float>(float*,
float*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
template void launch_gather_tokens<__half>(__half*,
__half*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
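/*
A scalar CPU reference of the gather above (a hypothetical validation helper,
not part of this commit): output row j of batch b copies the channels of token
gather_indices[b * sampled + j], with the strides covering both [N, S, C] and
[S, N, C] layouts.
*/
void gather_tokens_ref(float* out,
                       const float* act,
                       const int32_t* idx,
                       int batch,
                       int sampled,
                       int channels,
                       int read_batch_stride,
                       int read_seq_stride,
                       int write_batch_stride,
                       int write_seq_stride)
{
    for (int b = 0; b < batch; b++)
        for (int j = 0; j < sampled; j++) {
            const int src = read_batch_stride * b + read_seq_stride * idx[b * sampled + j];
            const int dst = write_batch_stride * b + write_seq_stride * j;
            for (int c = 0; c < channels; c++) out[dst + c] = act[src + c];
        }
}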
template <typename T>
__global__ void scatter_tokens_impl(T* all_activations,
const T* layer_activations,
int32_t* gather_indices,
int32_t retained_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride)
{
constexpr int mem_vals_t = td_data::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
const int gather_idx =
gather_indices[tb.group_index().x * retained_tokens + tb.group_index().y];
const int read_offset =
read_batch_stride * tb.group_index().x + read_seq_stride * tb.group_index().y;
const int write_offset =
write_batch_stride * tb.group_index().x + write_seq_stride * gather_idx;
for (int i = tb.thread_index().x * mem_vals_t; i < channels; i += mem_vals_t * blockDim.x) {
T local_data[mem_vals_t];
mem_access::load_global<td_data::granularity>(local_data,
layer_activations + read_offset + i);
mem_access::store_global<td_data::granularity>(all_activations + write_offset + i,
local_data);
}
}
template <typename T>
void launch_scatter_tokens(T* all_activations,
T* layer_activations,
int32_t* gather_indices,
int32_t batch_size,
int32_t sampled_tokens,
int32_t channels,
int32_t read_batch_stride,
int32_t read_seq_stride,
int32_t write_batch_stride,
int32_t write_seq_stride,
cudaStream_t stream)
{
constexpr int mem_vals_t = td_data::granularity / sizeof(T);
const int load_steps = (channels + mem_vals_t - 1) / mem_vals_t;
const int threads = (load_steps >= 1024) ? 1024 : load_steps;
dim3 block(threads);
dim3 grid(batch_size, sampled_tokens);
scatter_tokens_impl<T><<<grid, block, 0, stream>>>(all_activations,
layer_activations,
gather_indices,
sampled_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride);
}
template void launch_scatter_tokens<float>(float*,
float*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
template void launch_scatter_tokens<__half>(__half*,
__half*,
int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
torch::Tensor token_sort_(torch::Tensor& unsorted_token_ids, int64_t original_tokens)
{
const int layers = unsorted_token_ids.size(0);
const int batch_size = unsorted_token_ids.size(1);
const int reserved_tokens = unsorted_token_ids.size(2);
launch_token_sort(unsorted_token_ids.data_ptr<int32_t>(),
layers,
batch_size,
reserved_tokens,
original_tokens,
c10::cuda::getCurrentCUDAStream());
return unsorted_token_ids;
}
torch::Tensor token_gather(torch::Tensor& activations,
torch::Tensor& sorted_indices,
bool batch_first)
{
// Activations may be in either [N, S, C] or [S, N, C] while sorted_indices is
// always in [N, retained]
    /*
    TORCH_CHECK(sorted_indices.size(0) == activations.size(0) ||
                    sorted_indices.size(0) == activations.size(1),
                "Unable to match the batch size of the sorted indices to the activation shape.");
    TORCH_CHECK(activations.size(2) % 8 == 0,
                "Channels must be divisible by 8 to align with vectorized loads.");
    */
// bool batch_first = sorted_indices.size(0) == activations.size(0);
const int64_t dim_0 = (batch_first) ? sorted_indices.size(0) : sorted_indices.size(1);
const int64_t dim_1 = (batch_first) ? sorted_indices.size(1) : sorted_indices.size(0);
const int64_t dim_2 = activations.size(2);
auto output = torch::empty({dim_0, dim_1, dim_2}, activations.options());
const int batch_size = sorted_indices.size(0);
const int channels = dim_2;
const int retained_tokens = sorted_indices.size(1);
const int read_batch_stride = (batch_first) ? activations.stride(0) : activations.stride(1);
const int read_seq_stride = (batch_first) ? activations.stride(1) : activations.stride(0);
const int write_batch_stride = (batch_first) ? output.stride(0) : output.stride(1);
const int write_seq_stride = (batch_first) ? output.stride(1) : output.stride(0);
if (activations.options().dtype() == torch::kFloat) {
launch_gather_tokens((float*)output.data_ptr(),
(float*)activations.data_ptr(),
(int32_t*)sorted_indices.data_ptr(),
batch_size,
retained_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride,
c10::cuda::getCurrentCUDAStream());
} else {
launch_gather_tokens((__half*)output.data_ptr(),
(__half*)activations.data_ptr(),
(int32_t*)sorted_indices.data_ptr(),
batch_size,
retained_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride,
c10::cuda::getCurrentCUDAStream());
}
return output;
}
torch::Tensor token_scatter_(torch::Tensor& all_activations,
torch::Tensor& layer_activations,
torch::Tensor& sorted_indices,
bool batch_first)
{
// Activations may be in either [N, S, C] or [S, N, C] while sorted_indices is
// always in [N, retained]
    /*
    TORCH_CHECK(sorted_indices.size(0) == all_activations.size(0) ||
                    sorted_indices.size(0) == all_activations.size(1),
                "Unable to match the batch size of the sorted indices to the activation shape.");
    TORCH_CHECK(all_activations.size(2) % 8 == 0,
                "Channels must be divisible by 8 to align with vectorized loads.");
    */
// bool batch_first = sorted_indices.size(0) == all_activations.size(0);
const int batch_size = sorted_indices.size(0);
const int channels = all_activations.size(2);
const int retained_tokens = sorted_indices.size(1);
const int read_batch_stride = (batch_first) ? layer_activations.stride(0)
: layer_activations.stride(1);
const int read_seq_stride = (batch_first) ? layer_activations.stride(1)
: layer_activations.stride(0);
const int write_batch_stride = (batch_first) ? all_activations.stride(0)
: all_activations.stride(1);
const int write_seq_stride = (batch_first) ? all_activations.stride(1)
: all_activations.stride(0);
if (all_activations.options().dtype() == torch::kFloat) {
launch_scatter_tokens((float*)all_activations.data_ptr(),
(float*)layer_activations.data_ptr(),
(int32_t*)sorted_indices.data_ptr(),
batch_size,
retained_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride,
c10::cuda::getCurrentCUDAStream());
} else {
launch_scatter_tokens((__half*)all_activations.data_ptr(),
(__half*)layer_activations.data_ptr(),
(int32_t*)sorted_indices.data_ptr(),
batch_size,
retained_tokens,
channels,
read_batch_stride,
read_seq_stride,
write_batch_stride,
write_seq_stride,
c10::cuda::getCurrentCUDAStream());
}
return all_activations;
}
torch::Tensor mask_gather_bert(torch::Tensor& dense_mask, torch::Tensor& sorted_indices)
{
// TORCH_CHECK(dense_mask.dim() == 4)
const int batch_size = dense_mask.size(0);
const int layers = sorted_indices.size(0);
/*
TORCH_CHECK(layers * batch_size == sorted_indices.size(0),
"Mismatch between the indices and the mask");
*/
const int orig_seq_len = dense_mask.size(3);
const int truncated_seq_len = sorted_indices.size(2);
auto output = torch::empty({layers, batch_size, 1, truncated_seq_len, truncated_seq_len},
dense_mask.options());
if (dense_mask.options().dtype() == torch::kFloat) {
launch_slice_bert_mask((float*)output.data_ptr(),
(const float*)dense_mask.data_ptr(),
(const int32_t*)sorted_indices.data_ptr(),
layers,
batch_size,
truncated_seq_len,
orig_seq_len,
c10::cuda::getCurrentCUDAStream());
} else {
launch_slice_bert_mask((__half*)output.data_ptr(),
(const __half*)dense_mask.data_ptr(),
(const int32_t*)sorted_indices.data_ptr(),
layers,
batch_size,
truncated_seq_len,
orig_seq_len,
c10::cuda::getCurrentCUDAStream());
}
return output;
}
torch::Tensor mask_gather_gpt(torch::Tensor dense_mask, int truncated_seq_len)
{
// TORCH_CHECK(dense_mask.dim() == 4)
const int batch_size = dense_mask.size(0);
const int orig_seq_len = dense_mask.size(3);
auto output =
torch::empty({batch_size, 1, truncated_seq_len, truncated_seq_len}, dense_mask.options());
if (dense_mask.options().dtype() == torch::kFloat) {
launch_slice_gpt_mask((float*)output.data_ptr(),
(const float*)dense_mask.data_ptr(),
batch_size,
truncated_seq_len,
orig_seq_len,
c10::cuda::getCurrentCUDAStream());
} else {
launch_slice_gpt_mask((__half*)output.data_ptr(),
(const __half*)dense_mask.data_ptr(),
batch_size,
truncated_seq_len,
orig_seq_len,
c10::cuda::getCurrentCUDAStream());
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("token_sort_", &token_sort_, "Comparison free sorting algorithm (CUDA)");
m.def("token_gather", &token_gather, "Parallel gather of tokens (CUDA)");
m.def("token_scatter_", &token_scatter_, "Parallel scatter of tokens (CUDA)");
m.def("mask_gather_bert", &mask_gather_bert, "Token-based mask gather for BERT masking (CUDA)");
m.def("mask_gather_gpt", &mask_gather_gpt, "Token-based mask gather for GPT masking (CUDA)");
}
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void slice_gpt_mask_impl(T* output_mask,
const T* input_mask,
int truncated_seq_len,
int orig_seq_len)
{
const int in_batch_stride = orig_seq_len * orig_seq_len;
const int out_batch_stride = truncated_seq_len * truncated_seq_len;
cg::thread_block tb = cg::this_thread_block();
const T* input_mask_block =
input_mask + blockIdx.x * in_batch_stride + blockIdx.y * orig_seq_len;
T* output_mask_block =
output_mask + blockIdx.x * out_batch_stride + blockIdx.y * truncated_seq_len;
for (int i = tb.thread_index().x; i < truncated_seq_len; i += blockDim.x) {
output_mask_block[i] = input_mask_block[i];
}
}
template <typename T>
void launch_slice_gpt_mask(T* output_mask,
const T* input_mask,
int batch_size,
int truncated_seq_len,
int orig_seq_len,
cudaStream_t stream)
{
const int threads = (truncated_seq_len >= 1024) ? 1024 : truncated_seq_len;
dim3 block(threads);
dim3 grid(batch_size, truncated_seq_len);
slice_gpt_mask_impl<T>
<<<grid, block, 0, stream>>>(output_mask, input_mask, truncated_seq_len, orig_seq_len);
}
template void launch_slice_gpt_mask<float>(float*, const float*, int, int, int, cudaStream_t);
template void launch_slice_gpt_mask<__half>(__half*, const __half*, int, int, int, cudaStream_t);
template <typename T>
__global__ void slice_bert_mask_impl(T* output_mask,
const T* input_mask,
const int32_t* retained_indices,
int32_t truncated_seq_len,
int32_t orig_seq_len)
{
const int in_batch_stride = orig_seq_len * orig_seq_len;
const int out_batch_stride = truncated_seq_len * truncated_seq_len;
const int out_layer_stride = out_batch_stride * gridDim.y;
cg::thread_block tb = cg::this_thread_block();
const int out_layer_offset = tb.group_index().x * out_layer_stride;
const int in_batch_offset = tb.group_index().y * in_batch_stride;
const int out_batch_offset = tb.group_index().y * out_batch_stride;
const int32_t gather_row =
retained_indices[tb.group_index().y * truncated_seq_len + tb.group_index().z];
const int in_seq_offset = gather_row * orig_seq_len;
const int out_seq_offset = tb.group_index().z * truncated_seq_len;
const T* in_sequence = input_mask + in_batch_offset + in_seq_offset;
T* out_sequence = output_mask + out_layer_offset + out_batch_offset + out_seq_offset;
const int32_t* gather_data = retained_indices + tb.group_index().y * truncated_seq_len;
for (int i = tb.thread_index().x; i < truncated_seq_len; i += blockDim.x) {
out_sequence[i] = in_sequence[gather_data[i]];
}
}
/*
Since the Bert mask is not causal like GPT, we can't just generate a set of
masks for the entire model based off a single layer sample.
We map the kernel as follows:
z-dimension: layer
y-dimension: batch
x-dimension: sequence_offset
*/
template <typename T>
void launch_slice_bert_mask(T* output_mask,
const T* input_mask,
const int32_t* retained_indices,
int32_t layers,
int32_t batch_size,
int32_t truncated_seq_len,
int32_t orig_seq_len,
cudaStream_t stream)
{
const int threads = (truncated_seq_len >= 1024) ? 1024 : truncated_seq_len;
dim3 block(threads);
dim3 grid(layers, batch_size, truncated_seq_len);
slice_bert_mask_impl<T><<<grid, block, 0, stream>>>(
output_mask, input_mask, retained_indices, truncated_seq_len, orig_seq_len);
}
template void launch_slice_bert_mask<float>(float*,
const float*,
const int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
template void launch_slice_bert_mask<__half>(__half*,
const __half*,
const int32_t*,
int32_t,
int32_t,
int32_t,
int32_t,
cudaStream_t);
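/*
A scalar CPU reference of the BERT mask slice (a hypothetical validation
helper, not part of this commit). It mirrors the kernel's per-batch use of
retained_indices: row r and column c of each truncated mask are both gathered
through that batch's retained-token indices.
*/
void slice_bert_mask_ref(float* out,
                         const float* in,
                         const int32_t* retained,
                         int layers,
                         int batch,
                         int trunc,
                         int orig)
{
    for (int l = 0; l < layers; l++)
        for (int b = 0; b < batch; b++) {
            const int32_t* rows = retained + b * trunc;  // per-batch indices
            for (int r = 0; r < trunc; r++)
                for (int c = 0; c < trunc; c++)
                    out[((l * batch + b) * trunc + r) * trunc + c] =
                        in[(b * orig + rows[r]) * orig + rows[c]];
        }
}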
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
namespace td_sort {
constexpr int threads = 512;
constexpr int granularity = 16;
constexpr int mem_vals = granularity / sizeof(int32_t);
constexpr int max_buffer_size = (threads + 1) * mem_vals;
#ifdef __HIP_PLATFORM_HCC__
constexpr int warp_size = 64;
#else
constexpr int warp_size = 32;
#endif
constexpr int max_warps = threads / warp_size;
} // namespace td_sort
template <int VALS_PER_THREAD>
__global__ void scan_sort(int32_t* data, int reserved_tokens, int original_tokens)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<td_sort::warp_size> warp = cg::tiled_partition<td_sort::warp_size>(tb);
__shared__ int32_t indices_buffer[td_sort::max_buffer_size];
__shared__ int32_t intermediate_buffer[td_sort::max_warps];
__shared__ int32_t sorted_indices_buffer[td_sort::max_buffer_size];
for (int i = tb.thread_index().x * td_sort::mem_vals; i < original_tokens + 1;
i += tb.group_dim().x * td_sort::mem_vals) {
uint32_t zeros[td_sort::mem_vals] = {0, 0, 0, 0};
mem_access::store_shared<td_sort::granularity>(indices_buffer + i, zeros);
}
int32_t local_vals[VALS_PER_THREAD];
// We flatten layers/batch into a single indexing dimension
int32_t* data_block = data + tb.group_index().x * reserved_tokens;
// The next two loops really could be fused for a more logical code layout, but don't want to
// move the barrier forward
#pragma unroll
for (int i = 0; i < VALS_PER_THREAD; i++) {
const int iter_idx = i * td_sort::threads + tb.thread_index().x;
if (iter_idx < reserved_tokens) {
mem_access::load_global<sizeof(int32_t)>(local_vals + i, data_block + iter_idx);
} else {
local_vals[i] = 0;
}
}
tb.sync();
#pragma unroll
for (int i = 0; i < VALS_PER_THREAD; i++) {
const int iter_idx = i * td_sort::threads + tb.thread_index().x;
if (iter_idx < reserved_tokens) {
const int32_t one = 1;
mem_access::store_shared<sizeof(int32_t)>(indices_buffer + local_vals[i], &one);
}
}
tb.sync();
int32_t local_input[td_sort::mem_vals];
mem_access::load_shared<td_sort::granularity>(
local_input, indices_buffer + tb.thread_index().x * td_sort::mem_vals);
int32_t reduce_vals[td_sort::mem_vals];
reduce_vals[0] = local_input[0];
#pragma unroll
for (int i = 1; i < td_sort::mem_vals; i++) {
reduce_vals[i] = local_input[i] + reduce_vals[i - 1];
}
int32_t step_1_val = reduce_vals[td_sort::mem_vals - 1];
// Short span exclusive scan algorithm (less work efficient)
#pragma unroll
for (int i = 1; i < td_sort::warp_size; i *= 2) {
int32_t step_val = warp.shfl_up(step_1_val, i);
step_1_val = (warp.thread_rank() < i) ? step_1_val : step_1_val + step_val;
}
if (warp.thread_rank() == td_sort::warp_size - 1) {
mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.meta_group_rank(),
&step_1_val);
}
tb.sync();
if (warp.meta_group_rank() == 0) {
int32_t step_2_val = 0;
if (warp.thread_rank() < td_sort::max_warps) {
mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
intermediate_buffer + warp.thread_rank());
}
#pragma unroll
for (int i = 1; i < td_sort::warp_size; i *= 2) {
int32_t step_val = warp.shfl_up(step_2_val, i);
step_2_val = (warp.thread_rank() < i) ? step_2_val : step_2_val + step_val;
}
if (warp.thread_rank() < td_sort::max_warps) {
mem_access::store_shared<sizeof(int32_t)>(intermediate_buffer + warp.thread_rank(),
&step_2_val);
}
}
tb.sync();
int step_2_val = 0;
if (warp.meta_group_rank() > 0) {
mem_access::load_shared<sizeof(int32_t)>(&step_2_val,
intermediate_buffer + warp.meta_group_rank() - 1);
}
const int thread_offset = reduce_vals[td_sort::mem_vals - 1];
#pragma unroll
for (int i = 0; i < td_sort::mem_vals; i++) {
reduce_vals[i] += step_1_val + step_2_val - thread_offset;
}
mem_access::store_shared<td_sort::granularity>(
indices_buffer + tb.thread_index().x * td_sort::mem_vals, reduce_vals);
if (tb.thread_index().x == 0) {
indices_buffer[original_tokens] = original_tokens - indices_buffer[original_tokens];
}
tb.sync();
for (int i = 0; i < VALS_PER_THREAD; i++) {
const int iter_idx = i * td_sort::threads + tb.thread_index().x;
if (iter_idx < reserved_tokens) {
if (local_vals[i] == 0) {
int zero = 0;
mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer, &zero);
} else {
int sorted_idx;
mem_access::load_shared<sizeof(int32_t)>(&sorted_idx,
indices_buffer + local_vals[i] - 1);
mem_access::store_shared<sizeof(int32_t)>(sorted_indices_buffer + sorted_idx,
local_vals + i);
}
}
}
tb.sync();
#pragma unroll
for (int i = 0; i < VALS_PER_THREAD; i++) {
const int iter_idx = i * td_sort::threads + tb.thread_index().x;
if (iter_idx < reserved_tokens) {
int32_t store_val;
mem_access::load_shared<sizeof(int32_t)>(&store_val, sorted_indices_buffer + iter_idx);
mem_access::store_global<sizeof(int32_t)>(data_block + iter_idx, &store_val);
}
}
}
void launch_token_sort(int32_t* indices,
int layers,
int batch_size,
int reserved_size,
int original_tokens,
cudaStream_t stream)
{
// Each sort is completely independent, can flatten this dimension
dim3 grid(layers * batch_size);
dim3 block(td_sort::threads);
const int vals_per_thread = (reserved_size + td_sort::threads - 1) / td_sort::threads;
if (vals_per_thread == 1) {
scan_sort<1><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
} else if (vals_per_thread == 2) {
scan_sort<2><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
} else if (vals_per_thread == 3) {
scan_sort<3><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
} else if (vals_per_thread == 4) {
scan_sort<4><<<grid, block, 0, stream>>>(indices, reserved_size, original_tokens);
} else {
assert(false);
}
}
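/*
A sequential CPU reference of what scan_sort computes for one (layer, batch)
slice (a hypothetical validation helper, not part of this commit; assumes
<vector> is available): mark which token ids occur, prefix-scan the marks, and
use the scan values as sorted positions.
*/
void scan_sort_ref(std::vector<int32_t>& tokens, int original_tokens)
{
    // Histogram of which ids occur, mirroring indices_buffer in the kernel.
    std::vector<int32_t> present(original_tokens + 1, 0);
    for (int32_t t : tokens) present[t] = 1;
    // Inclusive prefix scan: scan[k] counts occurring ids <= k.
    std::vector<int32_t> scan(original_tokens + 1, 0);
    int32_t run = 0;
    for (int i = 0; i <= original_tokens; i++) {
        run += present[i];
        scan[i] = run;
    }
    // scan[t - 1] counts ids strictly below t, i.e. t's sorted slot
    // (the kernel's indices_buffer[local_vals[i] - 1]).
    std::vector<int32_t> sorted(tokens.size(), 0);
    for (int32_t t : tokens) {
        if (t != 0) sorted[scan[t - 1]] = t;  // id 0, if present, stays in slot 0
    }
    tokens = sorted;
}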
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "memory_access_utils.h"
#include "spatial_cuda_layers.h"
/*
Fused bias add variants
*/
namespace badd_opt {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
constexpr int vals_per_h = granularity / sizeof(__half);
constexpr int vals_per_h2 = granularity / sizeof(__half2);
constexpr int vals_per_block = threads * steps * vals_per_h;
constexpr int stride = vals_per_h * threads;
} // namespace badd_opt
__global__ void opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
for (int j = 0; j < badd_opt::vals_per_h2; j++) { act_buffer[j] += bias_buffer[j]; }
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
__global__ void opt_bias_add_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
__half2 other_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);
for (int j = 0; j < badd_opt::vals_per_h2; j++) {
act_buffer[j] += bias_buffer[j] + other_buffer[j];
}
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
__global__ void opt_bias_add_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int seq_len,
int channels)
{
const int id = blockIdx.x * badd_opt::vals_per_block + threadIdx.x * badd_opt::vals_per_h;
const int stride = badd_opt::vals_per_h * badd_opt::threads;
for (int i = 0; i < badd_opt::steps; i++) {
if (id + i * badd_opt::stride < seq_len * channels) {
__half2 act_buffer[badd_opt::vals_per_h2];
__half2 bias_buffer[badd_opt::vals_per_h2];
__half2 other_buffer[badd_opt::vals_per_h2];
__half2 other_bias_buffer[badd_opt::vals_per_h2];
mem_access::load_global<badd_opt::granularity>(act_buffer,
activation + id + i * stride);
mem_access::load_global<badd_opt::granularity>(bias_buffer,
bias + ((id + i * stride) % channels));
mem_access::load_global<badd_opt::granularity>(other_buffer, other + id + i * stride);
mem_access::load_global<badd_opt::granularity>(
other_bias_buffer, other_bias + ((id + i * stride) % channels));
for (int j = 0; j < badd_opt::vals_per_h2; j++) {
act_buffer[j] =
(act_buffer[j] + bias_buffer[j]) + (other_buffer[j] + other_bias_buffer[j]);
}
mem_access::store_global<badd_opt::granularity>(result + id + i * stride, act_buffer);
}
}
}
void launch_opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int batch_size,
int seq_len,
int channels,
cudaStream_t stream)
{
// Should evaluate `true` for reasonable hidden sizes
assert(channels % badd_opt::vals_per_h == 0);
const int effective_seq_len = batch_size * seq_len;
const int vals = effective_seq_len * channels;
dim3 block(badd_opt::threads);
dim3 grid((vals + badd_opt::vals_per_block - 1) / badd_opt::vals_per_block);
if (!other) {
        // There shouldn't be a second bias when there's no second activation to add
assert(!other_bias);
opt_bias_add<<<grid, block, 0, stream>>>(
result, activation, bias, effective_seq_len, channels);
} else if (!other_bias) {
opt_bias_add_add<<<grid, block, 0, stream>>>(
result, activation, bias, other, effective_seq_len, channels);
} else {
opt_bias_add_bias_add<<<grid, block, 0, stream>>>(
result, activation, bias, other, other_bias, effective_seq_len, channels);
}
}
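/*
A scalar CPU reference of the fused variants above (a hypothetical validation
helper, not part of this commit). It makes the broadcast explicit: each bias is
indexed by (flat index % channels), which is also why channels must be
divisible by the vector width asserted in the launcher.
*/
void opt_bias_add_bias_add_ref(float* result,
                               const float* activation,
                               const float* bias,
                               const float* other,
                               const float* other_bias,
                               int seq_len,
                               int channels)
{
    for (int i = 0; i < seq_len * channels; i++) {
        // Matches opt_bias_add_bias_add; drop terms for the simpler variants.
        result[i] = (activation[i] + bias[i % channels]) +
                    (other[i] + other_bias[i % channels]);
    }
}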
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <cstdio>
#include <vector>
#include "spatial_cuda_layers.h"
ChannelsLastProblem dimension_problem(at::Tensor& input)
{
ChannelsLastProblem dims;
if (input.dim() == 4) {
// In some sense this is unsafe (and a reflection of the assumptions made inside
// the C10 options checker). Basically, there's no great way to be sure that
// a tensor is in channels last because a 1x1 image will appear to be in channels
// last even when it isn't.
assert(input.is_contiguous(at::MemoryFormat::ChannelsLast));
dims.batch_size = input.size(0);
dims.seq_len = input.size(2) * input.size(3);
dims.channels = input.size(1);
} else {
assert(input.is_contiguous());
dims.batch_size = input.size(0);
dims.seq_len = input.size(1);
dims.channels = input.size(2);
}
return dims;
}
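/*
An illustrative mapping (hypothetical values, not part of this commit): a
channels-last activation of logical shape [N, C, H, W] = [8, 256, 32, 32]
flattens each of the H * W spatial positions into one sequence element.
*/
ChannelsLastProblem example_dims()
{
    ChannelsLastProblem dims;
    dims.batch_size = 8;     // N
    dims.seq_len = 32 * 32;  // H * W positions become the "sequence"
    dims.channels = 256;     // C, the fastest-moving dimension in channels-last
    return dims;
}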
at::Tensor seq_unroll_bias_add(at::Tensor& input, at::Tensor& bias)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
nullptr,
nullptr,
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor seq_bias_add_add(at::Tensor& input, at::Tensor& bias, at::Tensor& other)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)other.data_ptr(),
nullptr,
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor seq_bias_add_bias_add(at::Tensor& input,
at::Tensor& bias,
at::Tensor& other,
at::Tensor& other_bias)
{
assert(input.dtype() == at::kHalf);
// TODO(cmikeh2): Should probably refactor this into a more portable
// description, since it does generalize for channels-last
ChannelsLastProblem problem = dimension_problem(input);
auto output = at::empty_like(input);
launch_opt_bias_add((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)other.data_ptr(),
(const __half*)other_bias.data_ptr(),
problem.batch_size,
problem.seq_len,
problem.channels,
at::cuda::getCurrentCUDAStream());
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("nhwc_bias_add", &seq_unroll_bias_add);
m.def("nhwc_bias_add_add", &seq_bias_add_add);
m.def("nhwc_bias_add_bias_add", &seq_bias_add_bias_add);
}
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#pragma once
#if __CUDA_ARCH__ >= 530
#define HALF_PRECISION_AVAILABLE 1
#endif
#ifdef __HIPCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda.h>
#include <cuda_fp16.h>
/*********** Group Norm Kernels, Structs, and Helpers ************/
typedef struct {
    int64_t batch_size;
    int64_t seq_len;
    int64_t channels;
} ChannelsLastProblem;
void launch_opt_bias_add(__half* result,
const __half* activation,
const __half* bias,
const __half* other,
const __half* other_bias,
int batch_size,
int seq_len,
int channels,
cudaStream_t stream);
/*
Copyright The Microsoft DeepSpeed Team
*/
#include "cublas_wrappers.h"
#ifdef __HIP_PLATFORM_HCC__
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#include "custom_cuda_layers.h"
const int unroll_factor = 4;
......
......@@ -113,7 +113,6 @@ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
_seq_length,
_hidden_size / _heads,
//(T(1.0) / T(sqrt(_hidden_size / _heads))),
//aiss debug 0506
(T(1.0 / (sqrt(_hidden_size / _heads)))),
T(0.0),
CUBLAS_OP_T,
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#include "custom_cuda_layers.h"
inline __device__ float gelu(const float x)
......
/*
Copyright The Microsoft DeepSpeed Team
*/
#include "general_kernels.h"
namespace cg = cooperative_groups;
......