Unverified commit 3b4366be authored by Daniel Stokes, committed by GitHub
Browse files

Fix CI failures for UB overlap changes (#2149)


Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
parent 67fcc152
...@@ -264,7 +264,11 @@ def _train(opts): ...@@ -264,7 +264,11 @@ def _train(opts):
[batched_size, hidden_size], [batched_size, hidden_size],
tp_size, tp_size,
quantization_modes=[ quantization_modes=[
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE (
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
], ],
dtype=torch.bfloat16, dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend, bootstrap_backend=opts.bootstrap_backend,
......
...@@ -420,10 +420,14 @@ def _train(opts): ...@@ -420,10 +420,14 @@ def _train(opts):
} }
quantization_modes = [ quantization_modes = [
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE (
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
] ]
if opts.first_last_layers_bf16 and opts.fp8: if opts.first_last_layers_bf16 and opts.fp8:
quantization_modes.append(UserBufferQuantizationMode.NONE) quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)
te.module.base.initialize_ub( te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
......
...@@ -508,9 +508,9 @@ def main() -> None: ...@@ -508,9 +508,9 @@ def main() -> None:
torch.distributed.get_world_size(group), torch.distributed.get_world_size(group),
quantization_modes=[ quantization_modes=[
( (
UserBufferQuantizationMode.FP8 te.module.base.UserBufferQuantizationMode.FP8
if model_config.quantization is not None if model_config.quantization is not None
else UserBufferQuantizationMode.NONE else te.module.base.UserBufferQuantizationMode.NONE
) )
], ],
dtype=model_config.dtype, dtype=model_config.dtype,
......
...@@ -473,7 +473,7 @@ def initialize_ub( ...@@ -473,7 +473,7 @@ def initialize_ub(
fp8_buf = (name in layers_all_gather_overlap) or ( fp8_buf = (name in layers_all_gather_overlap) or (
user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
) )
ub_cfg.update(ub_cfgs[name]) ub_cfg.update(user_ub_cfg[name])
ub_cfg["fp8_buf"] = fp8_buf ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, quantization_mode, **ub_cfg) add_ub(name, quantization_mode, **ub_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment