"...git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "b96284adfa782319e063d5ce23b61a656cbef3ed"
Unverified commit a99c056b, authored by Oleg Goncharov, committed by GitHub
Browse files

[Common] Fixed integer overflow issue in cast kernels (#1988)



* Fixed integer overflow when computing offsets
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 71b2dd48
...@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method,
// Cache computations // Cache computations
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
float elt = static_cast<float>(input[idx]); float elt = static_cast<float>(input[idx]);
if (processing_method == ProcessingMethod::CAST_DBIAS) { if (processing_method == ProcessingMethod::CAST_DBIAS) {
...@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f; float block_amax = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal()); const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = i * scales_stride_rowwise + tile_X; const size_t scale_idx = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx] = biased_exponent; output_scales_rowwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent); const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal); output_rowwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
} }
} }
...@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method, ...@@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method,
float block_amax = 0.0f; float block_amax = 0.0f;
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx]));
} }
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal()); const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OutputType>::max_reciprocal());
const int scale_idx = tile_Y * scales_stride_colwise + j; const size_t scale_idx = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx] = biased_exponent; output_scales_colwise[scale_idx] = biased_exponent;
const float scale_reciprocal = exp2f_rcp(biased_exponent); const float scale_reciprocal = exp2f_rcp(biased_exponent);
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int idx = i * cols + j; const size_t idx = i * cols + j;
const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min);
output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal); output_colwise[idx] = static_cast<OutputType>(cache_buffer[cache_idx] * scale_reciprocal);
} }
} }
......
...@@ -64,7 +64,7 @@ void compute_ref(const IType* grad, ...@@ -64,7 +64,7 @@ void compute_ref(const IType* grad,
float silu_elt = static_cast<float>(input[i * stride + j]); float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]); float gate_elt = static_cast<float>(input[i * stride + cols + j]);
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
if (IS_DGATED) { if (IS_DGATED) {
const float x = silu_elt; const float x = silu_elt;
...@@ -101,7 +101,7 @@ void compute_ref(const IType* grad, ...@@ -101,7 +101,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f; float block_amax_act = 0.0f;
float block_amax_gate = 0.0f; float block_amax_gate = 0.0f;
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) { if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
...@@ -109,18 +109,18 @@ void compute_ref(const IType* grad, ...@@ -109,18 +109,18 @@ void compute_ref(const IType* grad,
} }
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = i * scales_stride_rowwise + tile_X; const size_t scale_idx_act = i * scales_stride_rowwise + tile_X;
output_scales_rowwise[scale_idx_act] = biased_exponent_act; output_scales_rowwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate; float scale_reciprocal_gate;
if (IS_DGATED) { if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32;
output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; output_scales_rowwise[scale_idx_gate] = biased_exponent_gate;
} }
for (size_t j = j_min; j < j_max; ++j) { for (size_t j = j_min; j < j_max; ++j) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) { if (IS_DGATED) {
...@@ -139,7 +139,7 @@ void compute_ref(const IType* grad, ...@@ -139,7 +139,7 @@ void compute_ref(const IType* grad,
float block_amax_act = 0.0f; float block_amax_act = 0.0f;
float block_amax_gate = 0.0f; float block_amax_gate = 0.0f;
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx]));
if (IS_DGATED) { if (IS_DGATED) {
block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx]));
...@@ -147,18 +147,18 @@ void compute_ref(const IType* grad, ...@@ -147,18 +147,18 @@ void compute_ref(const IType* grad,
} }
const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act);
const int scale_idx_act = tile_Y * scales_stride_colwise + j; const size_t scale_idx_act = tile_Y * scales_stride_colwise + j;
output_scales_colwise[scale_idx_act] = biased_exponent_act; output_scales_colwise[scale_idx_act] = biased_exponent_act;
float scale_reciprocal_gate; float scale_reciprocal_gate;
if (IS_DGATED) { if (IS_DGATED) {
const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal()); const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits<OType>::max_reciprocal());
const int scale_idx_gate = scale_idx_act + cols; const size_t scale_idx_gate = scale_idx_act + cols;
scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate);
output_scales_colwise[scale_idx_gate] = biased_exponent_gate; output_scales_colwise[scale_idx_gate] = biased_exponent_gate;
} }
for (size_t i = i_min; i < i_max; ++i) { for (size_t i = i_min; i < i_max; ++i) {
const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min);
const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act;
if (IS_DGATED) { if (IS_DGATED) {
......
...@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float *const scale_ptr, const size_t rows, const size_t cols) { const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
const int thread_offset_Y = tid_Y; const size_t thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X; const size_t thread_offset_X = tid_X;
float amax = 0; float amax = 0;
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
...@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int it = 0; it < ITERATIONS; ++it) { for (int it = 0; it < ITERATIONS; ++it) {
const int buff = it % BUFFERS_NUM; const size_t buff = it % BUFFERS_NUM;
const int next_it = it + 1; const size_t next_it = it + 1;
if (next_it < ITERATIONS) { if (next_it < ITERATIONS) {
const int next_buff = next_it % BUFFERS_NUM; const size_t next_buff = next_it % BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
copy_2d_to_sharedx3( copy_2d_to_sharedx3(
&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
...@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X; const size_t shmem_offset_x = thread_offset_X;
const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]); float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]); float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
...@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
// dGeLU // dGeLU
ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
...@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);
// # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X; const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X;
const int thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise; const size_t thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; const size_t thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise; const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
const int gate_scale_idx_offset_colwise = cols; const size_t gate_scale_idx_offset_colwise = cols;
// helps resolving bank conflicts in shmem // helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK; const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
__shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];
extern __shared__ char dynamic_shmem[]; extern __shared__ char dynamic_shmem[];
...@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations
IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < STAGES; ++stage) { for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM; const size_t buff = stage % BUFFS_NUM;
const int next_stage = stage + 1; const size_t next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y; const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) { if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory. // Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to // I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>(); ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM; const size_t next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM; const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
...@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mbarrier_wait_parity(&mbar[stage], parity); ptx::mbarrier_wait_parity(&mbar[stage], parity);
if constexpr (COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = const size_t shmem_offset_base_colwise =
buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
float thread_amax_act = 0.0f; float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f; float thread_amax_gate = 0.0f;
...@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 1. Read/Compute elements. Find MXFP8-block AMAX // 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_colwise = const size_t shmem_offset_colwise =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]); float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
...@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_act = const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise; const size_t global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
...@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate = const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate; scales_colwise[scale_idx_gate] = biased_exponent_gate;
} }
...@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 3. Scale elements // 3. Scale elements
#pragma unroll #pragma unroll
for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
const int shmem_offset_elt = const size_t shmem_offset_elt =
shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
OType2 out_pair; OType2 out_pair;
...@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
float thread_amax_act = 0.0f; float thread_amax_act = 0.0f;
float thread_amax_gate = 0.0f; float thread_amax_gate = 0.0f;
...@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
...@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else { } else {
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in_grad; Vec<IType, PACK_SIZE> in_grad;
Vec<IType, PACK_SIZE> in_act; Vec<IType, PACK_SIZE> in_act;
...@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor // 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act = const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise; const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) { if (!out_of_bounds_rowwise) {
...@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate = const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
if (!out_of_bounds_rowwise) { if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate; scales_rowwise[scale_idx_gate] = biased_exponent_gate;
} }
...@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
} }
} }
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
...@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM; const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
......
...@@ -80,33 +80,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -80,33 +80,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X; const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
const int tid_X_rowwise = threadIdx.x % THREADS_X; const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
const int tid_Y_colwise = 0; const size_t tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x; const size_t tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise; const size_t thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; const size_t thread_offset_X_colwise = tid_X_colwise;
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise; const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
// helps resolving bank conflicts in shmem // helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int thread_lane = threadIdx.x % THREADS_PER_WARP;
...@@ -139,7 +139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -139,7 +139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise); OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -173,20 +173,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -173,20 +173,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < STAGES; ++stage) { for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM; const size_t buff = stage % BUFFS_NUM;
const int next_stage = stage + 1; const size_t next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y; const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) { if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory. // Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to // I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>(); ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM; const size_t next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_DIM; const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
...@@ -205,7 +205,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -205,7 +205,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float thread_amax = 0.0f; float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f; thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y]; float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y]; IType in_colwise_IType[BUFF_DIM_Y];
...@@ -215,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -215,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType thread_amax_f16 = static_cast<IType>(0.0f); IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll #pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) { for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise]; in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
} }
...@@ -223,7 +223,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -223,7 +223,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else { } else {
#pragma unroll #pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) { for (int i = 0; i < BUFF_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]); float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) { if constexpr (IS_ACT) {
...@@ -263,9 +263,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -263,9 +263,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent = const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise; const size_t global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
scales_colwise[scale_idx] = biased_exponent; scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
...@@ -282,13 +283,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -282,13 +283,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
const float scaled_out = in * block_scale_inverse; const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out); out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
} }
} }
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f; thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X]; float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES]; Vec<IType, PACK_SIZE> in_cached[WAVES];
...@@ -301,9 +303,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -301,9 +303,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements // Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll #pragma unroll
...@@ -319,9 +321,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -319,9 +321,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)}; IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
...@@ -354,9 +356,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -354,9 +356,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} else { } else {
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in; Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in; Vec<IType, PACK_SIZE> act_in;
...@@ -406,9 +408,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -406,9 +408,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor // 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent = const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise; const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent; scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
...@@ -434,9 +436,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -434,9 +436,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
} }
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
} }
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
} }
} }
...@@ -452,9 +454,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -452,9 +454,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X; const size_t global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM; const size_t buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
...@@ -485,18 +487,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -485,18 +487,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts // Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem); float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset = const size_t shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll #pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) { for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e; const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e; const size_t shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
} }
} }
...@@ -504,15 +506,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -504,15 +506,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int i = 0; i < THREADS_Y; ++i) { for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32] // Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X; const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias += thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
} }
} }
const int dbias_stride = cols; const size_t dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y; const size_t dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) { if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias; dbias_workspace[dbias_idx] = thread_partial_dbias;
...@@ -561,19 +563,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -561,19 +563,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
const size_t cols) { const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;
const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;
const int thread_offset_Y = tid_Y; const size_t thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X; const size_t thread_offset_X = tid_X;
const int dbias_offset_Y = blockIdx.y + tid_Y; const size_t dbias_offset_Y = blockIdx.y + tid_Y;
const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
const bool col_out_of_bounds = my_column >= cols; const bool col_out_of_bounds = my_column >= cols;
const int dbias_stride = cols; const size_t dbias_stride = cols;
float partial_dbias = 0.f; float partial_dbias = 0.f;
...@@ -588,7 +590,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -588,7 +590,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
__shared__ alignas(TMA_SHMEM_ALIGNMENT) __shared__ alignas(TMA_SHMEM_ALIGNMENT)
OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -600,13 +602,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -600,13 +602,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
int parity = 0; int parity = 0;
const int chunk_offset_Y = block_offset_Y; const size_t chunk_offset_Y = block_offset_Y;
const int chunk_offset_X = block_offset_X; const size_t chunk_offset_X = block_offset_X;
#pragma unroll #pragma unroll
for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
const int chunk_stage_offset_X = chunk_offset_X; const size_t chunk_stage_offset_X = chunk_offset_X;
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
...@@ -621,13 +623,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -621,13 +623,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
const int buff = iter % FP8_BUFFERS_NUM; const size_t buff = iter % FP8_BUFFERS_NUM;
const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
if (next_iter < FP8_ITERATIONS) { if (next_iter < FP8_ITERATIONS) {
const int next_buff = next_iter % FP8_BUFFERS_NUM; const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
...@@ -644,9 +646,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -644,9 +646,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
const int stage_offset_Y = stage; const size_t stage_offset_Y = stage;
const int shmem_offset_y = thread_offset_Y + stage_offset_Y; const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
const int shmem_offset_x = thread_offset_X; const size_t shmem_offset_x = thread_offset_X;
const size_t row = row_base + shmem_offset_y; const size_t row = row_base + shmem_offset_y;
const bool row_out_of_bounds = row >= rows; const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
...@@ -685,8 +687,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -685,8 +687,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X; const size_t chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x, reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff])); chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
...@@ -704,8 +706,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) ...@@ -704,8 +706,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
parity ^= 1; parity ^= 1;
if constexpr (IS_DBIAS) { if constexpr (IS_DBIAS) {
const int dbias_offset_X = my_column; const size_t dbias_offset_X = my_column;
const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) { if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias; dbias_workspace[dbias_offset] = partial_dbias;
} }
...@@ -747,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -747,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset; const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset; OType *output = output_ptr + block_offset;
...@@ -758,8 +760,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -758,8 +760,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0); const bool is_master_thread = (threadIdx.x == 0);
...@@ -775,12 +777,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -775,12 +777,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll #pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) { for (int iter = 0; iter < ITERATIONS; ++iter) {
const int buff = iter % SHMEM_BUFFERS; const size_t buff = iter % SHMEM_BUFFERS;
const int it_offset = iter * SHMEM_DIM; const size_t it_offset = iter * SHMEM_DIM;
const int next_iter = iter + 1; const size_t next_iter = iter + 1;
const int next_buff = next_iter % SHMEM_BUFFERS; const size_t next_buff = next_iter % SHMEM_BUFFERS;
const int next_iter_offset = next_iter * SHMEM_DIM; const size_t next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) { if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
...@@ -794,7 +796,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -794,7 +796,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll #pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]); float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) { if constexpr (IS_ACT) {
elt = OP(elt, {}); elt = OP(elt, {});
...@@ -847,12 +849,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -847,12 +849,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; constexpr size_t DBIAS_THREADS_PER_BLOCK = 256;
template <int nvec, typename OType> template <int nvec, typename OType>
__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const int cols) { const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>; using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>; using OutputVec = Vec<OType, nvec>;
const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) { if (thread_id * nvec >= cols) {
return; return;
...@@ -883,8 +885,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) ...@@ -883,8 +885,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
template <typename IType> template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) { cudaStream_t stream) {
constexpr int reduce_dbias_store_bytes = 8; // stg.64 constexpr size_t reduce_dbias_store_bytes = 8; // stg.64
constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec);
...@@ -1244,8 +1246,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) { ...@@ -1244,8 +1246,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) { bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim(); const size_t cols = t->flat_last_dim();
constexpr int TMA_bytes = 16; constexpr size_t TMA_bytes = 16;
const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0; return cols % alignment_requirement == 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment