Unverified commit d90ced7c authored by Daniel Stokes, committed by GitHub

Add support for overlapping wgrad NCCL AG with dgrad GEMM (#1849)



* Add support for overlapping wgrad NCCL AG with dgrad GEMM
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>

* Remove unused wait on memcpy API from UB
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>

* Add better commenting to MXFP8 overlap
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>

---------
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Co-authored-by: dastokes <dastokes@dastokes-dvt-01.nvidia.com>
parent ecaf3e21
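The change in one picture: the wgrad all-gather is launched on the same CUDA stream the dgrad GEMM uses for its communication, so the NCCL kernels run concurrently with the dgrad compute. A minimal sketch of the pattern, using names from the diff below (`ub_obj_dgrad`, `gather_along_first_dim`, `ctx.tp_group` come from Transformer Engine's linear backward; this is illustrative, not a drop-in snippet):

```python
import torch

# Launch the wgrad all-gather on the dgrad GEMM's communication stream.
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
with torch.cuda.stream(dgrad_comm_stream):
    # Ordered after the dgrad GEMM's pending communication on this stream;
    # the gather then overlaps with dgrad compute on the main stream.
    grad_output, work = gather_along_first_dim(
        grad_outputs[0],
        ctx.tp_group,
        async_op=True,
        quantizer=ctx.grad_output_quantizer,
    )
work.wait()  # main stream waits for the gather before the wgrad GEMM runs
```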
@@ -66,6 +66,11 @@ def setup_common_extension() -> CMakeExtension:
     if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
         cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
+    # Add custom CMake arguments from environment variable
+    nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
+    if nvte_cmake_extra_args:
+        cmake_flags.extend(nvte_cmake_extra_args.split())
+
     # Project directory root
     root_path = Path(__file__).resolve().parent
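A quick illustration of the new build hook above: `NVTE_CMAKE_EXTRA_ARGS` is split on whitespace, so several CMake arguments can ride in one variable (the flag values here are examples, not project defaults):

```python
import os

# Mirror of the setup.py logic: one env var, many CMake flags.
os.environ["NVTE_CMAKE_EXTRA_ARGS"] = "-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=ON"
cmake_flags = []
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
if nvte_cmake_extra_args:
    cmake_flags.extend(nvte_cmake_extra_args.split())
print(cmake_flags)  # ['-DCMAKE_BUILD_TYPE=RelWithDebInfo', '-DCMAKE_VERBOSE_MAKEFILE=ON']
```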
@@ -273,7 +273,9 @@ def _main(opts):
         dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
     assert dist.is_nccl_available()
     dist.init_process_group(**dist_init_kwargs)
-    tp_group = dist.new_group(backend="nccl")
+    tp_group = dist.new_group(
+        backend="nccl", pg_options=dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
+    )
     tp_rank = dist.get_rank(tp_group)
     tp_size = dist.get_world_size(tp_group)
     dist_print(
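Both test drivers now request a high-priority NCCL stream for the TP group. This matters for the overlap: the gather's NCCL kernels compete with the dgrad GEMM for SMs, and a high-priority stream helps them get scheduled promptly. Standalone form, assuming the default process group is already initialized:

```python
import torch.distributed as dist

# NCCL collectives on this group are issued on a high-priority CUDA stream,
# so communication kernels can run ahead of concurrent compute kernels.
nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
tp_group = dist.new_group(backend="nccl", pg_options=nccl_options)
```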
@@ -323,6 +323,7 @@ def _train(opts):
         new_group_kwargs = {
             "backend": "nccl",
             "ranks": tp_rank_list,
+            "pg_options": dist.ProcessGroupNCCL.Options(is_high_priority_stream=True),
         }
     else:
         opts.tp = WORLD_SIZE
@@ -430,6 +430,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);
+
+  at::Stream get_communication_stream();
 }; // CommOverlap

 class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
@@ -449,6 +451,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
   at::Tensor get_buffer(bool local_chunk = false,
                         std::optional<std::vector<int64_t>> shape = std::nullopt);
+
+  at::Stream get_communication_stream();
 }; // CommOverlapP2P

 #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
@@ -216,6 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

+at::Stream CommOverlap::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device());
+}
+
 /***************************************************************************************************
  * CommOverlapP2P
  **************************************************************************************************/
@@ -300,3 +304,7 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
   const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
   return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
+
+at::Stream CommOverlapP2P::get_communication_stream() {
+  return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device());
+}
@@ -385,7 +385,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt);
+           py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlap::get_communication_stream);

   py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
              transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
@@ -402,5 +403,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
            py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
-           py::arg("shape") = std::nullopt);
+           py::arg("shape") = std::nullopt)
+      .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);
 }
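Note the asymmetry the bindings expose: `CommOverlap::get_communication_stream` returns the collective stream (`_stream_comm`) while the P2P variant returns its receive stream (`_stream_recv`); both are wrapped with `at::cuda::getStreamFromExternal`, i.e. borrowed rather than owned. From Python the result plugs into the usual stream context manager (sketch; `ub_obj` stands for either overlap object):

```python
import torch

comm_stream = ub_obj.get_communication_stream()
with torch.cuda.stream(comm_stream):
    # Work enqueued here is ordered after the overlap object's
    # pending communication on the same stream.
    ...
```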
@@ -743,6 +743,31 @@ class _LayerNormLinear(torch.autograd.Function):
         wgrad = None
         if ctx.requires_wgrad:
+            # Prepare grad output tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
+            if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
+                # UB does not support overlapping grad output
+                # all-gather with wgrad GEMM. Also, we can't
+                # convert row-scaled MXFP8 to column-scaled, so we
+                # can't reuse the grad output that was gathered
+                # for the dgrad GEMM. We work around by explicitly
+                # overlapping the NCCL operation with the dgrad GEMM.
+                ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
+
+                # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+                dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
+                with torch.cuda.stream(dgrad_comm_stream):
+                    # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather.
+                    # This ensures that we don't start until all communication for the dgrad GEMM is complete.
+                    grad_output, mxfp8_grad_output_work = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=ctx.grad_output_quantizer,
+                    )
+
+                # Synchronize with the main stream
+                mxfp8_grad_output_work.wait()
+
             # Prepare input tensor
             # Note: Synchronize tensor-parallel communication and
@@ -757,22 +782,6 @@ class _LayerNormLinear(torch.autograd.Function):
                     ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                     ln_out_total = ctx.input_quantizer(ln_out_total)

-            # Prepare grad output tensor
-            # Note: Synchronize tensor-parallel communication and
-            # make sure required data is available
-            if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
-                # UB does not support overlapping grad output
-                # all-gather with wgrad GEMM. Also, we can't
-                # convert row-scaled MXFP8 to column-scaled, so we
-                # can't reuse the grad output that was gathered
-                # for the dgrad GEMM. We work around with blocking
-                # all-gather for column-scaled MXFP8 data.
-                ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-                grad_output, _ = gather_along_first_dim(
-                    grad_outputs[0],
-                    ctx.tp_group,
-                    quantizer=ctx.grad_output_quantizer,
-                )
             if ctx.fp8 or ctx.debug:
                 if isinstance(grad_output, QuantizedTensorBase):
                     grad_output.update_usage(columnwise_usage=True)
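One subtlety in the hunk above: `mxfp8_grad_output_work.wait()` does not block the host. For NCCL process groups, `Work.wait()` makes the current CUDA stream wait on the collective, so the wgrad GEMM enqueued afterwards is correctly ordered without a CPU sync. Minimal sketch, assuming an initialized NCCL group `tp_group` and one GPU per rank:

```python
import torch
import torch.distributed as dist

inp = torch.ones(4, device="cuda")
out = torch.empty(4 * dist.get_world_size(tp_group), device="cuda")
work = dist.all_gather_into_tensor(out, inp, group=tp_group, async_op=True)
work.wait()  # stream-level dependency only; the host keeps running
# Kernels launched from here on (e.g. the wgrad GEMM) see gathered data.
```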
@@ -832,17 +832,6 @@ class _LayerNormMLP(torch.autograd.Function):
         fc2_wgrad = None
         if ctx.fc2_weight_requires_grad:
-            # Prepare input tensor
-            # Note: Synchronize tensor-parallel communication and
-            # make sure required data is available
-            if ctx.fp8 or ctx.debug:
-                if isinstance(act_out, QuantizedTensorBase):
-                    act_out.update_usage(columnwise_usage=True)
-                else:
-                    ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
-                    act_out = ctx.fc2_input_quantizer(act_out)
-
             # Prepare grad output tensor
             # Note: Synchronize tensor-parallel communication and
             # make sure required data is available
@@ -851,14 +840,33 @@ class _LayerNormMLP(torch.autograd.Function):
                 # all-gather with wgrad GEMM. Also, we can't
                 # convert row-scaled MXFP8 to column-scaled, so we
                 # can't reuse the grad output that was gathered
-                # for the dgrad GEMM. We work around with blocking
-                # all-gather for column-scaled MXFP8 data.
+                # for the dgrad GEMM. We work around by explicitly
+                # overlapping the NCCL operation with the dgrad GEMM.
                 ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-                grad_output, _ = gather_along_first_dim(
-                    grad_outputs[0],
-                    ctx.tp_group,
-                    quantizer=ctx.fc2_grad_output_quantizer,
-                )
+
+                # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+                dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream()
+                with torch.cuda.stream(dgrad_comm_stream):
+                    # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather.
+                    # This ensures that we don't start until all communication for the dgrad GEMM is complete.
+                    grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim(
+                        grad_outputs[0],
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=ctx.fc2_grad_output_quantizer,
+                    )
+
+                # Synchronize with the main stream
+                mxfp8_fc2_grad_output_work.wait()
+
+            # Prepare input tensor
+            # Note: Synchronize tensor-parallel communication and
+            # make sure required data is available
+            if ctx.fp8 or ctx.debug:
+                if isinstance(act_out, QuantizedTensorBase):
+                    act_out.update_usage(columnwise_usage=True)
+                else:
+                    ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
+                    act_out = ctx.fc2_input_quantizer(act_out)

             if ctx.fp8 or ctx.debug:
                 if isinstance(grad_output, QuantizedTensorBase):
                     grad_output.update_usage(columnwise_usage=True)
@@ -689,14 +689,23 @@ class _Linear(torch.autograd.Function):
                 # all-gather with wgrad GEMM. Also, we can't
                 # convert row-scaled MXFP8 to column-scaled, so we
                 # can't reuse the grad output that was gathered
-                # for the dgrad GEMM. We work around with blocking
-                # all-gather for column-scaled MXFP8 data.
+                # for the dgrad GEMM. We work around by explicitly
+                # overlapping the NCCL operation with the dgrad GEMM.
                 ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-                grad_output, _ = gather_along_first_dim(
-                    grad_output_arg,
-                    ctx.tp_group,
-                    quantizer=ctx.grad_output_quantizer,
-                )
+
+                # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+                dgrad_comm_stream = ub_obj_dgrad.get_communication_stream()
+                with torch.cuda.stream(dgrad_comm_stream):
+                    # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather.
+                    # This ensures that we don't start until all communication for the dgrad GEMM is complete.
+                    grad_output, grad_output_work = gather_along_first_dim(
+                        grad_output_arg,
+                        ctx.tp_group,
+                        async_op=True,
+                        quantizer=ctx.grad_output_quantizer,
+                    )
+
+                # Synchronize with the main stream
+                grad_output_work.wait()
             if ctx.fp8 or ctx.debug:
                 if isinstance(grad_output, QuantizedTensorBase):
                     grad_output.update_usage(columnwise_usage=True)
@@ -407,16 +407,26 @@ class UserbuffersBackwardLinear(FusedOperation):
         # Initialize grad output
         if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
             # UB does not support overlapping grad output
-            # all-gather with wgrad GEMM. Also, MXFP8 does not
-            # allow reusing the grad output that was gathered for
-            # the dgrad GEMM. We work around with blocking
-            # all-gather for column-scaled MXFP8 data.
+            # all-gather with wgrad GEMM. Also, we can't
+            # convert row-scaled MXFP8 to column-scaled, so we
+            # can't reuse the grad output that was gathered
+            # for the dgrad GEMM. We work around by explicitly
+            # overlapping the NCCL operation with the dgrad GEMM.
             grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
-            dy, _ = gather_along_first_dim(
-                grad_output,
-                tensor_parallel_group,
-                quantizer=grad_output_quantizer,
-            )
+
+            # Get the communication stream from the dgrad GEMM and set it as the current torch stream
+            dgrad_comm_stream = ub_comm_dgrad.get_communication_stream()
+            with torch.cuda.stream(dgrad_comm_stream):
+                # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather.
+                # This ensures that we don't start until all communication for the dgrad GEMM is complete.
+                dy, dy_work = gather_along_first_dim(
+                    dy_local,
+                    tensor_parallel_group,
+                    async_op=True,
+                    quantizer=grad_output_quantizer,
+                )
+
+            # Synchronize with the main stream
+            dy_work.wait()
         if tensor_parallel_mode == "column":
             dy = dy_local
         if dy is None: