Commit 40a4d896 authored by wenjh

Fix kernel crash on block_len=64


Signed-off-by: wenjh <wenjh@sugon.com>
parent b944277c
@@ -570,7 +570,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
     mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
     mock_group = mock_groups[rank]
-    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
+    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}
     # Create model with FP8 weights
     with te.fp8.fp8_model_init(
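A minimal sketch of what the flipped flag controls, assuming te is transformer_engine.pytorch and the test's layer is te.Linear (the feature sizes below are illustrative, not taken from the test):

import torch
import transformer_engine.pytorch as te

linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False,
                 "fuse_wgrad_accumulation": False}
layer = te.Linear(128, 128, **linear_kwargs)

# With fuse_wgrad_accumulation=True, Transformer Engine accumulates weight
# gradients directly into a preallocated FP32 buffer exposed as
# weight.main_grad instead of populating weight.grad, so the test would
# also have to attach that buffer, e.g.:
# layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)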
@@ -248,7 +248,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   using transformer_engine::Vec;
   static_assert(sizeof(OType) == 1);
-  constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType);
+  constexpr int kNumOutputElemsPerBank = 2 / sizeof(OType);
   constexpr int kThreadsPerWarp = 32;
   constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
   constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
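To make the constant change concrete, here is the surrounding constexpr arithmetic redone in plain Python; sizeof(OType) is 1 per the static_assert above, and reading kNumOutputElemsPerBank as "FP8 outputs packed into one 4-byte shared-memory bank" is an assumption for illustration, not taken from the kernel:

SIZEOF_OTYPE = 1            # 1-byte FP8 output type, per the static_assert
K_TILE_DIM_64 = 64          # tile edge on the block_len == 64 path
K_THREADS_PER_WARP = 32

old_elems_per_bank = 4 // SIZEOF_OTYPE               # 4 before the fix
new_elems_per_bank = 2 // SIZEOF_OTYPE               # 2 after the fix
loops_per_row = K_TILE_DIM_64 // K_THREADS_PER_WARP  # 2: a warp covers a 64-wide row in two passes

# Under the old packing a 64-element row spans 64 // 4 = 16 banks, so two warp
# lanes land on every bank; halving the packing spreads the row over
# 64 // 2 = 32 banks, one per lane.
print(K_TILE_DIM_64 // old_elems_per_bank)   # 16
print(K_TILE_DIM_64 // new_elems_per_bank)   # 32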