Unverified commit 3b4366be, authored by Daniel Stokes and committed by GitHub
Browse files

Fix CI failures for UB overlap changes (#2149)


Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
parent 67fcc152
......@@ -264,7 +264,11 @@ def _train(opts):
[batched_size, hidden_size],
tp_size,
quantization_modes=[
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
(
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
],
dtype=torch.bfloat16,
bootstrap_backend=opts.bootstrap_backend,
......
......@@ -420,10 +420,14 @@ def _train(opts):
}
quantization_modes = [
UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE
(
te.module.base.UserBufferQuantizationMode.FP8
if opts.fp8
else te.module.base.UserBufferQuantizationMode.NONE
)
]
if opts.first_last_layers_bf16 and opts.fp8:
quantization_modes.append(UserBufferQuantizationMode.NONE)
quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE)
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
......
......@@ -508,9 +508,9 @@ def main() -> None:
torch.distributed.get_world_size(group),
quantization_modes=[
(
UserBufferQuantizationMode.FP8
te.module.base.UserBufferQuantizationMode.FP8
if model_config.quantization is not None
else UserBufferQuantizationMode.NONE
else te.module.base.UserBufferQuantizationMode.NONE
)
],
dtype=model_config.dtype,
......
......@@ -473,7 +473,7 @@ def initialize_ub(
fp8_buf = (name in layers_all_gather_overlap) or (
user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"]
)
ub_cfg.update(ub_cfgs[name])
ub_cfg.update(user_ub_cfg[name])
ub_cfg["fp8_buf"] = fp8_buf
add_ub(name, quantization_mode, **ub_cfg)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment