Commit 5b82e699 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents 9a815d0b 7f946529
...@@ -167,7 +167,7 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase) ...@@ -167,7 +167,7 @@ class TestFP8BlockScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase)
dtype=dtype, dtype=dtype,
y_error=0.9, y_error=0.9,
ln_out_error=0.5, ln_out_error=0.5,
dgrad_error=1.5, dgrad_error=1,
wgrad_error=1, wgrad_error=1,
bgrad_error=0.5, bgrad_error=0.5,
recipe1_golden_tensors=None, recipe1_golden_tensors=None,
......
...@@ -116,7 +116,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -116,7 +116,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset && if (h_in_input < h && w_in_input < w && idx_in_input >= start_offset &&
idx_in_input < end_offset) { idx_in_input < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale; float inp = static_cast<float>(input_minus_offset[idx_in_input]) * scale;
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp); if constexpr(std::is_same_v<OType, int8_t>) {
smem[h_in_smem][w_in_smem] = static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, inp))));
}
else {
smem[h_in_smem][w_in_smem] = static_cast<OType>(inp);
}
skip_store = false; skip_store = false;
} }
} }
......
...@@ -431,9 +431,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose ...@@ -431,9 +431,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
for (int j = 0; j < THREAD_TILE_DIM_X; j++) { for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output // Step 3: Store cast output
CType scale_data = block_tile_scale; CType scale_data = block_tile_scale;
OType scaled_elt = 0;
OType scaled_elt = if constexpr(std::is_same_v<OType, int8_t>) {
scaled_elt =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
}
else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data); static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt; tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile // Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) { if constexpr (kReturnTranspose) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment