Unverified Commit a99c056b authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Fixed integer overflow issue in cast kernels (#1988)



* Fixed integer overflow when computing offsets
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 71b2dd48
...@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations // Cache computations
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]); float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) { if (processing_method == ProcessingMethod::CAST_DBIAS) {
...@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f; float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal()); const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = i * scales_stride_rowwise + tile_X; const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent; output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent); const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal); output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
} }
} }
...@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f; float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal()); const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = tile_Y * scales_stride_colwise + j; const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent; output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent); const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal); output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
} }
} }
......
...@@ -64,7 +64,7 @@ void compute_ref(const IType* grad, ...@@ -64,7 +64,7 @@ void compute_ref(const IType* grad,
float silu_elt = static_cast<float>(input[i * stride + j]); float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); float gate_elt = static_cast<float>(input[i * stride + cols + j]);
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) { if (IS_DGATED) {
const float x = silu_elt; const float x = silu_elt;
...@@ -101,7 +101,7 @@ void compute_ref(const IType* grad, ...@@ -101,7 +101,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f; float block_amax_act = 0.0f;
float block_amax_gate = 0.0f; float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) { if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
...@@ -109,18 +109,18 @@ void compute_ref(const IType* grad, ...@@ -109,18 +109,18 @@ void compute_ref(const IType* grad,
} }
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X; const size_t scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act; output_scales_rowwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate; float scale_reciprocal_gate;
if (IS_DGATED) { if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
} }
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) { if (IS_DGATED) {
...@@ -139,7 +139,7 @@ void compute_ref(const IType* grad, ...@@ -139,7 +139,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f; float block_amax_act = 0.0f;
float block_amax_gate = 0.0f; float block_amax_gate = 0.0f;
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) { if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
...@@ -147,18 +147,18 @@ void compute_ref(const IType* grad, ...@@ -147,18 +147,18 @@ void compute_ref(const IType* grad,
} }
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = tile_Y * scales_stride_colwise + j; const size_t scale_idx_act = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx_act] = biased_exponent_act; output_scales_colwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate; float scale_reciprocal_gate;
if (IS_DGATED) { if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const int scale_idx_gate = scale_idx_act + cols; const size_t scale_idx_gate = scale_idx_act + cols;
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output_scales_colwise[scale_idx_gate] = biased_exponent_gate; output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
} }
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) { if (IS_DGATED) {
......
...@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float *const scale_ptr, const size_t rows, const size_t cols) { const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y; const size_t thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X; const size_t thread_offset_X = tid_X;
float amax = 0; float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
...@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int it = 0; it < ITERATIONS; ++it) { for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM; const size_t buff = it % BUFFERS_NUM;
const int next_it = it + 1; const size_t next_it = it + 1;
if (next_it < ITERATIONS) { if (next_it < ITERATIONS) {
const int next_buff = next_it % BUFFERS_NUM; const size_t next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
copy_2d_to_sharedx3( copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
...@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X; const size_t shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]); float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]); float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
...@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU // dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
...@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);
// # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X; const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const int thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise; const size_t thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; const size_t thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise; const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const int gate_scale_idx_offset_colwise = cols; const size_t gate_scale_idx_offset_colwise = cols;
// helps resolving bank conflicts in shmem // helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK; const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
__shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];
extern __shared__ char dynamic_shmem[]; extern __shared__ char dynamic_shmem[];
...@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations
IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < STAGES; ++stage) { for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM; const size_t buff = stage % BUFFS_NUM;
const int next_stage = stage + 1; const size_t next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y; const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) { if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory. // Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to // I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>(); ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM; const size_t next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM; const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
...@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mbarrier_wait_parity(&mbar[stage], parity); ptx::mbarrier_wait_parity(&mbar[stage], parity);
if constexpr (COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = const size_t shmem_offset_base_colwise =
buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
float thread_amax_act = 0.0f; float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f; float thread_amax_gate = 0.0f;
...@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 1. Read/Compute elements. Find MXFP8-block AMAX // 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_colwise = const size_t shmem_offset_colwise =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]); float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
...@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_act = const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise; const size_t global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
...@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate = const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate; scales_colwise[scale_idx_gate] = biased_exponent_gate;
} }
...@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 3. Scale elements // 3. Scale elements
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_elt = const size_t shmem_offset_elt =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
OType2 out_pair; OType2 out_pair;
...@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f; float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f; float thread_amax_gate = 0.0f;
...@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
...@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else { } else {
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in_grad; Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act; Vec<IType, PACK_SIZE> in_act;
...@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor // 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act = const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise; const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) { if (!out_of_bounds_rowwise) {
...@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate = const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) { if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate; scales_rowwise[scale_idx_gate] = biased_exponent_gate;
} }
...@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
} }
} }
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
...@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM; const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment