Commit 331f2fc4 authored by wenjh
Browse files

Resolve a merge issue from the NVIDIA upstream in the vector blockwise cast-transpose kernel


Signed-off-by: wenjh <wenjh@sugon.com>
parent 1f9c104b
......@@ -504,7 +504,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero, kThreadsPerWarp);
#else
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
#endif
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
......@@ -528,9 +532,15 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
OVec output_vec;
#pragma unroll
for (int i = 0; i < kThreadTileCol; ++i) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]))));
} else {
output_vec.data.elt[i] = static_cast<OType>(
static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
}
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g);
......@@ -558,9 +568,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
bool return_rowwise = rowwise_option == FP8BlockwiseRowwiseOption::ROWWISE;
bool return_columnwise_transpose =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_TRANSPOSE;
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
bool return_columnwise_compact =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
......@@ -724,7 +736,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
// Step 3: Transpose, cast and store to output_t
if (return_columnwise_transpose) {
if (return_columnwise_gemm_ready) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
......@@ -825,6 +837,113 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
}
}
}
// Step 4 (return_columnwise_compact): cast in 128x1 style and store to output, skip transpose.
// Each warp quantizes a kTileDim-row column strip: per-column amax is reduced across the warp,
// one scale per 128x1 column block is produced, and the quantized tile is stored untransposed.
if (return_columnwise_compact) {
// thread tile should be 4x16, 16 means 8 smem reads
constexpr int kThreadTileRow = kTileDim / kThreadsPerWarp;
constexpr int kThreadTileCol = kNVecOut;
using RegVec = Vec<IType, kThreadTileCol>;
using RegScaleVec = Vec<CType, kThreadTileCol>;
// Each SMemVec holds kNVecSMem elements, so kNVecOut output elements need this many reads.
constexpr int num_smem_reads = kNVecOut / kNVecSMem;
// c_stride will not be used here because we only have one iteration
// constexpr int c_stride = kThreadTileCol * kNumWarps / kNVecSMem;
constexpr int num_iterations =
kTileDim / (kNumWarps * kThreadTileCol); // should be only one iteration
static_assert(num_iterations == 1,
"num_iterations should be 1 for columnwise non-transpose case");
const int thr_idx_in_warp = threadIdx.x % kThreadsPerWarp;
const int warp_idx = threadIdx.x / kThreadsPerWarp;
const int r_s = thr_idx_in_warp * kThreadTileRow; // Row in shared memory
int c_s = warp_idx * num_smem_reads; // Column in shared memory (in SMemVec units)
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
const size_t num_ele = c_g < row_length
? min(static_cast<size_t>(kThreadTileCol), row_length - c_g)
: 0; // For not aligned case: clamp element count at the right edge of the tensor
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
RegVec reg_vec[kThreadTileRow];
RegScaleVec thr_scale;
// Step 4.1: Load from shared memory to registers
// (smem is indexed as [c * kTileDim + r], i.e. column-major in SMemVec units)
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
int r = r_s + i;
#pragma unroll
for (int j = 0; j < num_smem_reads; ++j) {
int c = c_s + j;
SMemVec smem_vec = smem[c * kTileDim + r];
// copy smem_vec to reg vec with its elements
#pragma unroll
for (int k = 0; k < kNVecSMem; ++k) {
reg_vec[i].data.elt[j * kNVecSMem + k] = smem_vec.data.elt[k];
}
}
}
#pragma unroll
for (int reg_idx = 0; reg_idx < kThreadTileCol; ++reg_idx) {
// Step 4.2: Compute local amax over this thread's kThreadTileRow rows of the column
CType amax = 0;
#pragma unroll
for (int i = 0; i < kThreadTileRow; ++i) {
amax = fmaxf(amax, fabsf(reg_vec[i].data.elt[reg_idx]));
}
// Step 4.3: Reduce amax across the warp, then broadcast lane 0's result to all lanes
const bool is_src_lane = thr_idx_in_warp == 0;
amax = warp_reduce_max<kThreadsPerWarp>(amax);
constexpr int lane_zero = 0;
#ifdef __HIP_PLATFORM_AMD__
// HIP overload takes an explicit width argument (AMD wavefronts may be 64 lanes wide)
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero, kThreadsPerWarp);
#else
amax = __shfl_sync(0xFFFFFFFF, amax, lane_zero);
#endif
// Step 4.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
thr_scale.data.elt[reg_idx] = scale;
// Step 4.5: Write scale_inv_t (one value per 128x1 block, written by lane 0 only)
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (c_g + reg_idx < row_length);
}
if (write_scale_inv) {
// NOTE(review): 1.0 / scale is evaluated in double, then narrowed to CType
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y);
size_t col_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + reg_idx;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 4.6: Quantize (int8 output is rounded and clamped to [-127, 127]; other OTypes cast directly)
for (int row_idx = 0; row_idx < kThreadTileRow; ++row_idx) {
OType* output_g =
&output_t[(r_g + row_idx) * row_length + c_g]; // Output address in global memory
OVec output_vec;
#pragma unroll
for (int i = 0; i < kThreadTileCol; ++i) {
if constexpr (std::is_same_v<OType, int8_t>) {
output_vec.data.elt[i] = static_cast<OType>(lroundf(
fmaxf(-127.0f,
fminf(127.0f, static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]))));
} else {
output_vec.data.elt[i] = static_cast<OType>(
static_cast<CType>(reg_vec[row_idx].data.elt[i]) * thr_scale.data.elt[i]);
}
}
// Step 4.7: Store output_t (vectorized when aligned; element-wise with row/col bounds checks otherwise)
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g + row_idx < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
}
// Step 4.8: Update output address, column index of shared memory
// this section shouldn't matter since we only have one iteration
}
}
}
#endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment