Unverified Commit 077e26c3 authored by Daniel Stokes's avatar Daniel Stokes Committed by GitHub
Browse files

Use userbuffers for MXFP8 wgrad all-gather overlap (#1982)



* fix: Add stream synchronization before destroying MPI communicator (#1979)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* feat: Implement column-wise userbuffer overlap for comm+GEMM operations

Add support for overlapping column-wise allgather communication with GEMM
operations to improve training performance:

* **Core infrastructure changes:**
  - Update bulk_overlap_columnwise_ag() to accept explicit stream parameter
  - Modify userbuffers send/recv loops to use rank-ordered iteration
  - Add userbuffers_send_all/recv_all function declarations

* **Python integration:**
  - Add bulk_overlap_ag_with_external_gemm() C++ extension function
  - Expose new overlap function via pybind11 bindings
  - Update overlap method configurations to include more ring_exchange ops

* **LayerNorm MLP optimization:**
  - Enable column-wise quantization for FC2 gradient output
  - Implement overlap of allgather communication with FC2 DGRAD GEMM
  - Use fill_userbuffers_buffer_for_all_gather for efficient buffering

This optimization allows overlapping communication and computation phases
more effectively, reducing training wall-clock time by hiding allgather
latency behind GEMM execution.
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Working userbuffer overlapping API
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Fix overwriting bulk overlap UB object for layernormLinear
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Update external overlap to use tp size instead of nvsize to determine number of copies
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Fix linter error
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Explanatory comments of overlap logic
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Fix the UB fused ops tests
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

* fix: Fix linter errors
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

---------
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 235c8d00
...@@ -519,6 +519,7 @@ def _train(opts): ...@@ -519,6 +519,7 @@ def _train(opts):
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
del test_graph del test_graph
torch.cuda.synchronize()
te.module.base.destroy_ub() te.module.base.destroy_ub()
dist_print("Destroying Userbuffers objects...", debug=True) dist_print("Destroying Userbuffers objects...", debug=True)
......
...@@ -138,6 +138,11 @@ CommOverlapCore::~CommOverlapCore() { ...@@ -138,6 +138,11 @@ CommOverlapCore::~CommOverlapCore() {
cudaStreamDestroy(_stream_compute[i]); cudaStreamDestroy(_stream_compute[i]);
} }
auto error = cudaGetLastError();
if (error != cudaSuccess) {
NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error));
}
if (_comm_created) { if (_comm_created) {
try { try {
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
...@@ -289,6 +294,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType ...@@ -289,6 +294,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
CommOverlapBase::~CommOverlapBase() { CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy); cudaEventDestroy(_start_d2dcopy);
cudaStreamSynchronize(_stream_comm);
cudaStreamDestroy(_stream_comm); cudaStreamDestroy(_stream_comm);
} }
...@@ -591,6 +597,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons ...@@ -591,6 +597,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs } // CommOverlapBase::split_overlap_rs
void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                                               cudaStream_t stream_main) {
  // All-gather this communicator's userbuffer across the tensor-parallel group, overlapping
  // the communication with a GEMM that was previously launched by another communicator.
  // The caller passes in that GEMM communicator's send/recv streams so the send/recv kernels
  // launched here are serialized behind the in-flight GEMM work on those streams and the
  // overlap happens implicitly via stream ordering.
  // Use size_t for byte counts: `int` would truncate for buffers >= 2 GiB.
  const size_t comm_bytes = _ubuf.bytes();
  const size_t comm_bytes_per_rank = comm_bytes / _tp_size;
  userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
                       send_stream);
  userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm,
                       recv_stream);
  for (auto stream : {send_stream, recv_stream}) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream));
    // Main stream must not consume the gathered data until both comm streams are done.
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
    // Also make _stream_comm wait, so the destructor's sync on _stream_comm covers this
    // communication before the ubuf is freed.
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0));
  }
}
/*************************************************************************************************** /***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange) * Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -2535,6 +2535,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds ...@@ -2535,6 +2535,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
} }
} }
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = 1; j < tp_size; j++) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * tp_rank;
int recv_offset = dstoffset + bytes_per_slice * tp_rank;
userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
int tp_size, communicator *comm, cudaStream_t stream) {
for (int j = tp_size - 1; j > 0; j--) {
int i = (tp_rank + j) % tp_size;
int send_offset = srcoffset + bytes_per_slice * i;
int recv_offset = dstoffset + bytes_per_slice * i;
userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i,
stream);
}
}
// producer // producer
static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) {
// Decrement atomic val to signal current output tile finish // Decrement atomic val to signal current output tile finish
......
...@@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp ...@@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp
void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream);
// All-gather helpers over a registered userbuffers buffer: send_all pushes this rank's
// slice (at index tp_rank) to every peer; recv_all receives every peer's slice into the
// local buffer at that peer's index. `bytes_per_slice` is the per-rank slice size in bytes;
// peers are iterated in rank order relative to tp_rank, and self (j == 0) is skipped.
void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);
void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler,
                          const size_t dstoffset, const size_t bytes_per_slice, int tp_rank,
                          int tp_size, communicator *comm, cudaStream_t stream);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
...@@ -36,7 +36,8 @@ enum class CommOverlapAlgo { ...@@ -36,7 +36,8 @@ enum class CommOverlapAlgo {
SPLIT_PIPELINED_RS_P2P = 4, SPLIT_PIPELINED_RS_P2P = 4,
ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_RS = 5,
ATOMIC_GEMM_AG_P2P = 6, ATOMIC_GEMM_AG_P2P = 6,
ATOMIC_GEMM_RS_P2P = 7 ATOMIC_GEMM_RS_P2P = 7,
EXTERNAL_BULK_OVERLAP_AG = 8,
}; };
class CommOverlapCore { class CommOverlapCore {
...@@ -133,6 +134,11 @@ class CommOverlapCore { ...@@ -133,6 +134,11 @@ class CommOverlapCore {
cudaStream_t stream_main) { cudaStream_t stream_main) {
NVTE_ERROR("Operation is not implemented."); NVTE_ERROR("Operation is not implemented.");
} }
/*
** Default implementation: bulk all-gather overlap with an externally-launched GEMM is
** only available in subclasses that override this (CommOverlapBase); the base class
** raises an error.
*/
virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                                      cudaStream_t stream_main) {
  NVTE_ERROR("Operation is not implemented.");
}
}; // CommOverlapCore }; // CommOverlapCore
class CommOverlapBase : public CommOverlapCore { class CommOverlapBase : public CommOverlapCore {
...@@ -198,6 +204,9 @@ class CommOverlapBase : public CommOverlapCore { ...@@ -198,6 +204,9 @@ class CommOverlapBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override; cudaStream_t stream_main) override;
/*
** Overlaps the all-gather of this communicator's userbuffer with a GEMM previously
** launched by another communicator; `send_stream`/`recv_stream` are that communicator's
** streams, and `stream_main` waits for the gather to complete.
*/
void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                              cudaStream_t stream_main) override;
}; // CommOverlapBase }; // CommOverlapBase
class CommOverlapP2PBase : public CommOverlapCore { class CommOverlapP2PBase : public CommOverlapCore {
...@@ -277,6 +286,15 @@ class CommOverlapP2PBase : public CommOverlapCore { ...@@ -277,6 +286,15 @@ class CommOverlapP2PBase : public CommOverlapCore {
TensorWrapper &workspace, bool grad, bool accumulate, TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) override; cudaStream_t stream_main) override;
/*
** Bulk all-gather overlap with an externally-launched GEMM is not supported for the
** ring-exchange (P2P) communicator; only CommOverlapBase implements it. Calling this
** override always raises an error.
*/
void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream,
                              cudaStream_t stream_main) override {
  NVTE_ERROR("Operation not supported.");
}
}; // CommOverlapP2PBase }; // CommOverlapP2PBase
} // namespace transformer_engine } // namespace transformer_engine
......
...@@ -94,7 +94,9 @@ ...@@ -94,7 +94,9 @@
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \
.value("EXTERNAL_BULK_OVERLAP_AG", \
transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \
py::class_<transformer_engine::CommOverlapCore, \ py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \ std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \ pybind11::module_local()) \
......
...@@ -11,6 +11,10 @@ ...@@ -11,6 +11,10 @@
#include "common.h" #include "common.h"
class CommOverlapHelper;
class CommOverlap;
class CommOverlapP2P;
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
/*************************************************************************************************** /***************************************************************************************************
...@@ -419,6 +423,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k ...@@ -419,6 +423,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k
void nvshmem_finalize(); void nvshmem_finalize();
/***************************************************************************************************
* Comm+GEMM Overlap Wrappers
**************************************************************************************************/
void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream,
at::Stream recv_stream);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
/*************************************************************************************************** /***************************************************************************************************
...@@ -468,7 +479,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve ...@@ -468,7 +479,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at::Tensor get_buffer(bool local_chunk = false, at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt); std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream(); std::pair<at::Stream, at::Stream> get_communication_stream();
}; // CommOverlap }; // CommOverlap
...@@ -489,7 +500,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm ...@@ -489,7 +500,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at::Tensor get_buffer(bool local_chunk = false, at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt); std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream(); std::pair<at::Stream, at::Stream> get_communication_stream();
}; // CommOverlapP2P }; // CommOverlapP2P
......
...@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i ...@@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<i
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
} }
at::Stream CommOverlap::get_communication_stream() { std::pair<at::Stream, at::Stream> CommOverlap::get_communication_stream() {
return at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()); // Return the same stream for both send and recv
return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()),
at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())};
} }
/*************************************************************************************************** /***************************************************************************************************
...@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto ...@@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vecto
return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA)); return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
} }
at::Stream CommOverlapP2P::get_communication_stream() { std::pair<at::Stream, at::Stream> CommOverlapP2P::get_communication_stream() {
return at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device()); return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()),
at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())};
}
// Launches the all-gather of `allgather_communicator`'s userbuffer on the supplied
// send/recv streams so it overlaps with a GEMM previously issued by another communicator.
// The current PyTorch CUDA stream is used as the main (consumer) stream.
void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm(
    CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) {
  const at::cuda::CUDAStream send_cuda_stream{send_stream};
  const at::cuda::CUDAStream recv_cuda_stream{recv_stream};
  allgather_communicator.bulk_overlap_external_ag(send_cuda_stream, recv_cuda_stream,
                                                  at::cuda::getCurrentCUDAStream());
}
...@@ -374,6 +374,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -374,6 +374,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
"Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>()); "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
// Comm+GEMM Overlap
m.def("bulk_overlap_ag_with_external_gemm",
&transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm,
"Bulk overlap All-Gather with a GEMM operation launched by another communicator",
py::call_guard<py::gil_scoped_release>(), py::arg("allgather_communicator"),
py::arg("send_stream"), py::arg("recv_stream"));
// Data structures // Data structures
py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta") py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
.def(py::init<>()) .def(py::init<>())
......
...@@ -151,7 +151,7 @@ def initialize_ub( ...@@ -151,7 +151,7 @@ def initialize_ub(
``` ```
for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad",
"proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad",
"fc2_fprop", "fc2_dgrad"]`. "fc2_fprop", "fc2_wgrad"]`.
bootstrap_backend : str = None bootstrap_backend : str = None
`torch.distributed` communication backend for the all-gather, broadcast and `torch.distributed` communication backend for the all-gather, broadcast and
barrier collectives during Userbuffers initialization. Not all backends are barrier collectives during Userbuffers initialization. Not all backends are
...@@ -250,22 +250,31 @@ def initialize_ub( ...@@ -250,22 +250,31 @@ def initialize_ub(
"qkv_fprop", "qkv_fprop",
"qkv_dgrad", "qkv_dgrad",
"proj_dgrad", "proj_dgrad",
"proj_wgrad",
"fc1_fprop", "fc1_fprop",
"fc1_dgrad", "fc1_dgrad",
"fc2_dgrad", "fc2_dgrad",
"fc2_wgrad",
] ]
layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
# Default overlap methods for layers # Default overlap methods for layers
methods = { methods = {
"ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "ring_exchange": [
"qkv_fprop",
"fc1_fprop",
"proj_dgrad",
"fc2_dgrad",
],
"pipeline": ["proj_fprop", "fc2_fprop"], "pipeline": ["proj_fprop", "fc2_fprop"],
"bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
"external": ["proj_wgrad", "fc2_wgrad"],
} }
# AG-RS overlap pairs of layers forming a tensor-parallel block # AG-RS overlap pairs of layers forming a tensor-parallel block
ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"}
rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()}
external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"}
global layers_atomic_ring_exchange global layers_atomic_ring_exchange
layers_atomic_ring_exchange = [] layers_atomic_ring_exchange = []
...@@ -319,7 +328,7 @@ def initialize_ub( ...@@ -319,7 +328,7 @@ def initialize_ub(
"Atomic GEMM uses a beta API from cublas and is not tested for all use cases." "Atomic GEMM uses a beta API from cublas and is not tested for all use cases."
) )
assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM."
if method == "bulk": if method in ("bulk", "external"):
warnings.warn( warnings.warn(
f"At {name}, atoimic GEMM not is supported for a bulk overlap." f"At {name}, atoimic GEMM not is supported for a bulk overlap."
"Defaulting to `atomic_gemm=False`." "Defaulting to `atomic_gemm=False`."
...@@ -348,6 +357,16 @@ def initialize_ub( ...@@ -348,6 +357,16 @@ def initialize_ub(
if atomic_gemm and method == "ring_exchange": if atomic_gemm and method == "ring_exchange":
assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message
if name in external_gemm_to_overlap:
assert method == "external", (
f"At {name}, `external` overlap method is specified, but the selected method is"
f" {method}"
)
assert external_gemm_to_overlap[name] in methods["ring_exchange"], (
f"At {name}, `external` overlap method is specified, but the external gemm"
f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method"
)
buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype
if method == "ring_exchange": if method == "ring_exchange":
ub_obj = tex.CommOverlapP2P( ub_obj = tex.CommOverlapP2P(
...@@ -396,7 +415,9 @@ def initialize_ub( ...@@ -396,7 +415,9 @@ def initialize_ub(
new_method = ub_cfgs[name]["method"] new_method = ub_cfgs[name]["method"]
methods[new_method].append(name) methods[new_method].append(name)
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: for name in (
methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"]
):
ub_cfg = get_default_config(name) ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs: if ub_cfgs is not None and name in ub_cfgs:
fp8_buf = (name in layers_all_gather_overlap) or ( fp8_buf = (name in layers_all_gather_overlap) or (
......
...@@ -758,27 +758,36 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -758,27 +758,36 @@ class _LayerNormLinear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream # We use the send stream to copy into the userbuffers.
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() # This is the same stream that we will use to access the data in the AG,
with torch.cuda.stream(dgrad_comm_stream): # so we dont need to add any syncs yet.
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather with torch.cuda.stream(dgrad_send_stream):
# This ensures that we don't start until all communication for the dgrad GEMM is complete grad_output, _ = fill_userbuffers_buffer_for_all_gather(
grad_output, mxfp8_grad_output_work = gather_along_first_dim( ub_obj_overlap_wgrad,
grad_outputs[0], grad_outputs[0],
ctx.grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
) )
# Synchronize with the main stream
mxfp8_grad_output_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
......
...@@ -851,26 +851,37 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -851,26 +851,37 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = (
ub_obj_fc2_dgrad.get_communication_stream()
)
ub_obj_fc2_wgrad = get_ub("fc2_wgrad")
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() # We use the send stream to copy into the userbuffers.
with torch.cuda.stream(dgrad_comm_stream): # This is the same stream that we will use to access the data in the AG,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather # so we dont need to add any syncs yet.
# This ensures that we don't start until all communication for the dgrad GEMM is complete with torch.cuda.stream(dgrad_send_stream):
grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc2_wgrad,
grad_outputs[0], grad_outputs[0],
ctx.fc2_grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.fc2_grad_output_quantizer,
) )
# Synchronize with the main stream
mxfp8_fc2_grad_output_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_fc2_wgrad, dgrad_send_stream, dgrad_recv_stream
)
# Prepare input tensor # Prepare input tensor
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
......
...@@ -745,26 +745,36 @@ class _Linear(torch.autograd.Function): ...@@ -745,26 +745,36 @@ class _Linear(torch.autograd.Function):
# Note: Synchronize tensor-parallel communication and # Note: Synchronize tensor-parallel communication and
# make sure required data is available # make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream()
# This object is separate from the ub_obj_wgrad object which is passed to the GEMM
ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad")
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() # We use the send stream to copy into the userbuffers.
with torch.cuda.stream(dgrad_comm_stream): # This is the same stream that we will use to access the data in the AG,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather # so we dont need to add any syncs yet.
# This ensures that we don't start until all communication for the dgrad GEMM is complete with torch.cuda.stream(dgrad_send_stream):
grad_output, grad_output_work = gather_along_first_dim( grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
grad_output_arg, grad_output_arg,
ctx.grad_output_quantizer,
ctx.tp_group, ctx.tp_group,
async_op=True,
quantizer=ctx.grad_output_quantizer,
) )
# Synchronize with the main stream
grad_output_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
tex.bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
if ctx.fp8 or ctx.debug: if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase): if isinstance(grad_output, QuantizedTensorBase):
......
...@@ -10,9 +10,9 @@ import warnings ...@@ -10,9 +10,9 @@ import warnings
import torch import torch
from transformer_engine_torch import CommOverlapType from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm
from ...cpp_extensions import general_gemm from ...cpp_extensions import general_gemm
from ...distributed import gather_along_first_dim, get_distributed_world_size from ...distributed import get_distributed_world_size
from ...module.base import ( from ...module.base import (
fill_userbuffers_buffer_for_all_gather, fill_userbuffers_buffer_for_all_gather,
get_ub, get_ub,
...@@ -398,26 +398,35 @@ class UserbuffersBackwardLinear(FusedOperation): ...@@ -398,26 +398,35 @@ class UserbuffersBackwardLinear(FusedOperation):
# Initialize grad output # Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output # UB does not support pipelined overlapping grad output
# all-gather with wgrad GEMM. Also, we can't # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we # convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered # can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around by explicitly # for the dgrad GEMM. We work around by explicitly
# overlapping the NCCL operation with the dgrad GEMM. # overlapping the AG operation with the dgrad GEMM.
# Get the communication stream from the dgrad GEMM to use for the AG
dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream()
ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad")
grad_output_quantizer.set_usage(rowwise=False, columnwise=True) grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
# Get the communication stream from the dgrad GEMM and set it as the current torch stream
dgrad_comm_stream = ub_comm_dgrad.get_communication_stream() # We use the send stream to copy into the userbuffers.
with torch.cuda.stream(dgrad_comm_stream): # This is the same stream that we will use to access the data in the AG,
# Syncs with the current stream (dgrad_comm_stream) before starting the all-gather # so we dont need to add any syncs yet.
# This ensures that we don't start until all communication for the dgrad GEMM is complete with torch.cuda.stream(dgrad_send_stream):
dy, dy_work = gather_along_first_dim( dy, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_overlap_wgrad,
dy_local, dy_local,
grad_output_quantizer,
tensor_parallel_group, tensor_parallel_group,
async_op=True,
quantizer=grad_output_quantizer,
) )
# Synchronize with the main stream
dy_work.wait() # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm
bulk_overlap_ag_with_external_gemm(
ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream
)
if tensor_parallel_mode == "column": if tensor_parallel_mode == "column":
dy = dy_local dy = dy_local
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment