Commit c25a91b6 authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.9.2-rocm' into 'main'

Ds v0.9.2 rocm

See merge request dcutoolkit/deeplearing/deepspeed!2
parents d1596c94 af82b300
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <cassert> #include <cassert>
#include "conversion_utils.h" #include "conversion_utils.h"
...@@ -101,9 +102,9 @@ public: ...@@ -101,9 +102,9 @@ public:
if (max == min) { if (max == min) {
scale = 1.0; scale = 1.0;
} else { } else {
scale = (1 << numBits) / (max - min); scale = ((1 << numBits)) / (max - min);
} }
offset = -(1 << (numBits - 1)) - (min * scale); offset = (max + min) / 2;
} }
DS_D_INLINE int8_t quantize(__half val) DS_D_INLINE int8_t quantize(__half val)
...@@ -111,7 +112,7 @@ public: ...@@ -111,7 +112,7 @@ public:
constexpr int32_t q_min = -(1 << (numBits - 1)); constexpr int32_t q_min = -(1 << (numBits - 1));
constexpr int32_t q_max = (1 << (numBits - 1)) - 1; constexpr int32_t q_max = (1 << (numBits - 1)) - 1;
float val_f = conversion::to<float>(val) * scale + offset; float val_f = (conversion::to<float>(val) - offset) * scale;
int32_t data_i32 = conversion::to<int32_t>(val_f); int32_t data_i32 = conversion::to<int32_t>(val_f);
data_i32 = min(max(data_i32, q_min), q_max); data_i32 = min(max(data_i32, q_min), q_max);
return (int8_t)data_i32; return (int8_t)data_i32;
...@@ -120,7 +121,7 @@ public: ...@@ -120,7 +121,7 @@ public:
template <typename T> template <typename T>
DS_D_INLINE T dequantize(int8_t val) DS_D_INLINE T dequantize(int8_t val)
{ {
const float val_deq_f = conversion::to<float>(val) * scale + offset; const float val_deq_f = ((conversion::to<float>(val)) * scale) + offset;
return conversion::to<__half>(val_deq_f); return conversion::to<__half>(val_deq_f);
} }
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
#ifdef __HIP_PLATFORM_HCC__
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h> #include <cooperative_groups.h>
#endif
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#pragma once #pragma once
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */ /* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h> #include <ATen/ATen.h>
......
/* Copyright 2019 The Microsoft DeepSpeed Team */ // Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <torch/extension.h> #include <torch/extension.h>
// CUDA forward declaration // CUDA forward declaration
......
/* Copyright 2019 The Microsoft DeepSpeed Team */ // Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h> #include <stdio.h>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "dequantization_utils.h" #include "dequantization_utils.h"
#include "memory_access_utils.h" #include "memory_access_utils.h"
......
/* // Copyright (c) Microsoft Corporation.
Copyright The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include <math.h> #include <math.h>
#include "custom_cuda_layers.h" #include "custom_cuda_layers.h"
...@@ -456,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals, ...@@ -456,7 +457,7 @@ void launch_sr_fake_quantize_kernel(T* vals,
dim3 grid_dim(group_num); dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x; uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc); std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>( sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed); vals, (total_count / group_num) / 4, group_num, num_bits, seed);
...@@ -1010,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals, ...@@ -1010,7 +1011,7 @@ void launch_sr_fake_quantize_kernel_asym(T* vals,
dim3 grid_dim(group_num); dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x; uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc); std::pair<uint64_t, uint64_t> seed = TrainingContext::Instance().IncrementOffset(inc);
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>( sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed); vals, (total_count / group_num) / 4, group_num, num_bits, seed);
......
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0
// DeepSpeed Team
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <cassert> #include <cassert>
......
/* // Copyright (c) Microsoft Corporation.
Copyright 2022 The Microsoft DeepSpeed Team // SPDX-License-Identifier: Apache-2.0
*/
// DeepSpeed Team
#include "ds_kernel_utils.h" #include "ds_kernel_utils.h"
#include "memory_access_utils.h" #include "memory_access_utils.h"
......
#include <math.h>
#include "custom_cuda_layers.h"
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
q_data_int[0].x =
(rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
: q_data_int[0].x;
q_data_int[0].y =
(rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
: q_data_int[0].y;
q_data_int[1].x =
(rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
: q_data_int[1].x;
q_data_int[1].y =
(rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x / q_scale_val;
data_f[0].y = q_data_int[0].y / q_scale_val;
data_f[1].x = q_data_int[1].x / q_scale_val;
data_f[1].y = q_data_int[1].y / q_scale_val;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
data[reg_count] = vals_cast[group_index];
if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)(q_data.x * q_scale_val));
q_data_int.y = (float)((int)(q_data.y * q_scale_val));
q_data_int.w = (float)((int)(q_data.w * q_scale_val));
q_data_int.z = (float)((int)(q_data.z * q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
q_data_int.x =
(rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
: q_data_int.x;
q_data_int.y =
(rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
: q_data_int.y;
q_data_int.w =
(rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
: q_data_int.w;
q_data_int.z =
(rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
: q_data_int.z;
q_data_int.x /= q_scale_val;
q_data_int.y /= q_scale_val;
q_data_int.w /= q_scale_val;
q_data_int.z /= q_scale_val;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
float min = 10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
float min = 10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
float high_q = (float)((1 << num_bits) - 1);
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
curandStatePhilox4_32_10_t state;
curand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = curand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment