Commit 1e018a45 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents 4ef4eae6 40a4d896
...@@ -570,7 +570,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): ...@@ -570,7 +570,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group):
mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
mock_group = mock_groups[rank] mock_group = mock_groups[rank]
linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": False}
# Create model with FP8 weights # Create model with FP8 weights
with te.fp8.fp8_model_init( with te.fp8.fp8_model_init(
......
...@@ -248,7 +248,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) ...@@ -248,7 +248,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
using transformer_engine::Vec; using transformer_engine::Vec;
static_assert(sizeof(OType) == 1); static_assert(sizeof(OType) == 1);
constexpr int kNumOutputElemsPerBank = 4 / sizeof(OType); constexpr int kNumOutputElemsPerBank = 2 / sizeof(OType);
constexpr int kThreadsPerWarp = 32; constexpr int kThreadsPerWarp = 32;
constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp; constexpr int kLoopsPerRow = kTileDim64 / kThreadsPerWarp;
constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment