Commit 5b6ef054 authored by yuguo
parents 76060570 a7eeb28b
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_silu);
using namespace transformer_engine;
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
}
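/*
** Illustrative usage sketch (not part of this file): how a caller might invoke the SwiGLU C API
** above. The TensorWrapper helper, the [N, 2H] gated-input / [N, H] output shape convention, and
** the concrete sizes/dtypes are assumptions for the example, not a prescribed interface.
**
**   using transformer_engine::TensorWrapper;
**   using transformer_engine::DType;
**   cudaStream_t stream = 0;                                 // default stream (assumed)
**   float *in_ptr, *out_ptr;
**   cudaMalloc(&in_ptr, 128 * 2048 * sizeof(float));         // gated input [128, 2048]
**   cudaMalloc(&out_ptr, 128 * 1024 * sizeof(float));        // output      [128, 1024]
**   TensorWrapper input(in_ptr, std::vector<size_t>{128, 2048}, DType::kFloat32);
**   TensorWrapper output(out_ptr, std::vector<size_t>{128, 1024}, DType::kFloat32);
**   nvte_swiglu(input.data(), output.data(), stream);        // forward
**   // backward: grad has the output shape, dinput the gated-input shape (allocated analogously)
**   nvte_dswiglu(grad.data(), input.data(), dinput.data(), stream);
*/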
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
#include <numeric>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using namespace std::placeholders;
namespace transformer_engine {
/***************************************************************************************************
* Comm+GEMM Overlap Common Core
**************************************************************************************************/
bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
return true;
#else
return false;
#endif
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm) {
// Initialize userbuf communicator
if (!_comm_created) {
if (myrank == 0) {
printf("!!! [UB] Create Userbuffers Communicator\n");
}
#ifdef NVTE_UB_WITH_MPI
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
allgather_handle, barrier_handle, 1, 1, tp_size, 1);
#endif
_comm_created = true;
}
_use_ce = static_cast<int>(use_ce);
_num_comm_sm = num_comm_sm;
_cga_size = comm_cga_size;
if (gemm_priority == 0 && comm_priority == 0) {
transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority);
} else {
_gemm_priority = gemm_priority;
_comm_priority = comm_priority;
}
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
_stream_compute.push_back(std::move(stream));
}
_num_splits = num_splits;
_rank = _ub_comm->myrank;
_tp_size = tp_size;
_tp_id = _rank % _tp_size;
// Set the number of SMs for GEMM with margin
int sm_count = transformer_engine::cuda::sm_count();
_math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);
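// e.g. on a 132-SM GPU with num_comm_sm = 16 and set_sm_margin = true, _math_sms = 132 - 16 = 116
// (further reduced by NVTE_EXT_MARGIN_SM if that variable is set).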
_atomic_gemm = atomic_gemm;
if (_atomic_gemm) {
void *counter_ptr;
size_t counter_bytes = _num_splits * 2 * sizeof(int32_t);
NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes));
NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2));
_counter = TensorWrapper(counter_ptr, std::vector<size_t>{static_cast<size_t>(_num_splits * 2)},
DType::kInt32);
}
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
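// The events above implement the standard CUDA cross-stream ordering idiom used throughout this
// file: cudaEventRecord(event, producer_stream) marks a point on the producer stream, and
// cudaStreamWaitEvent(consumer_stream, event, 0) holds back later work on the consumer stream
// until that point has executed.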
/*
Define the launch order between the communication and GEMM kernels
using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS > 1.
The event is used to schedule the communication kernel before the GEMM.
This is needed only for Hopper, which uses persistent CTA execution.
*/
int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
int runtime_version = 0;
cudaRuntimeGetVersion(&runtime_version);
cudaDeviceProp deviceProp;
cudaGetDeviceProperties(&deviceProp, 0);
if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
} else {
_comm_launch_event = 0;
}
}
CommOverlapCore::~CommOverlapCore() {
cudaEventDestroy(_stop_comm);
cudaEventDestroy(_start_comm);
cudaEventDestroy(_stop_compute);
cudaEventDestroy(_start_compute);
if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);
if (_atomic_gemm) cudaFree(_counter.dptr());
for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);
if (_comm_created) {
#ifdef NVTE_UB_WITH_MPI
destroy_communicator_mpi(_ub_comm);
#else
destroy_communicator(_ub_comm);
#endif
_comm_created = false;
}
}
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
TensorWrapper chunk;
for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
auto param_type = static_cast<NVTETensorParam>(param_id);
auto param = source.get_parameter(param_type);
auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
auto param_dtype = static_cast<DType>(param.dtype);
auto param_shape = AS_VECTOR(param.shape);
if (param_dptr != nullptr) {
if (param_type == NVTETensorParam::kNVTERowwiseData ||
param_type == NVTETensorParam::kNVTEColumnwiseData) {
// Offset data pointer
param_dptr += chunk_offset * typeToSize(param_dtype);
param_shape = chunk_shape;
if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
// Columnwise shape for non-block scaled tensors shifts the last dimension to the front
auto last_dim = param_shape.back();
param_shape.pop_back();
param_shape.insert(param_shape.begin(), last_dim);
}
} else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
(param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
// Calculate block scaling offset and size
auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? source.shape().data[0]
: source.columnwise_shape().data[0];
auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
? chunk_shape.front()
: chunk_shape.back();
auto chunk_scale_start = chunk_offset / 32;
auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
param_dptr += chunk_scale_start * typeToSize(param_dtype);
param_shape = std::vector<size_t>{chunk_scale_size};
}
// Set chunked source parameters into the chunked tensor output
chunk.set_parameter(param_type, reinterpret_cast<void *>(param_dptr), param_dtype,
param_shape);
}
}
return chunk;
}
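// Example of the chunking arithmetic above (illustrative sizes): splitting a row-major
// [1024, 512] fp16 tensor into 4 row chunks uses chunk_shape = {256, 512} and
// chunk_offset = i * 256 * 512 elements, so the rowwise data pointer advances by
// chunk_offset * typeToSize(dtype) = i * 262144 bytes.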
TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source,
size_t chunk_offset,
const std::vector<size_t> &chunk_shape) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);
// Update chunk with offset data pointers from the communication buffer
auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
}
if (chunk.columnwise_dptr() != nullptr) {
chunk.set_columnwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(),
chunk.columnwise_shape());
}
return chunk;
}
/***************************************************************************************************
* Comm+GEMM Overlap Base (Pipelined / Collective)
**************************************************************************************************/
CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal, int mynode,
int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
bool rs_overlap_first_gemm)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
atomic_gemm) {
_rs_overlap_first_gemm = rs_overlap_first_gemm;
_rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
"Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
"or 2 (multi-atomic).");
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
CommOverlapBase::~CommOverlapBase() {
cudaEventDestroy(_start_d2dcopy);
cudaStreamDestroy(_stream_comm);
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator,
CommOverlapType comm_type, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));
// Communication: AG and RS
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
if (comm_type == CommOverlapType::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
if (_ubuf.element_size() == 1) {
assert(_ubuf_scale_inv_initialized);
comm_elements *= 2;
assert(rs_output.numel() == _ubuf.numel() / _tp_size);
assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
assert(rs_output.element_size() == 2);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
} else {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
(cudaEvent_t)_comm_launch_event);
}
}
assert(pre_gelu_out.numel() == 0);
// When a kernel launch order is defined, make the GEMM kernel launch wait for the communication kernel launch
if (_comm_launch_event)
NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0));
nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
_stream_compute[0]);
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::bulk_overlap
/*
** Atomic FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _num_splits, false, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
auto output_d = get_buffer_chunk_like(D, 0, {n, m});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),
_stream_compute[0]);
for (int i = 0; i < _num_splits; i++) {
if (_rs_kernel_type == 1) {
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
&counter_ptr[i], _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_num_splits, &counter_ptr[i], _ub_comm,
_stream_comm);
}
} else if (_rs_kernel_type == 2) {
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
counter_ptr, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m,
_num_splits, counter_ptr, _ub_comm,
_stream_comm);
}
break;
} else {
consumer(counter_ptr, i, _stream_comm);
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, D.scale_inv(),
_ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
_ub_comm, _stream_comm);
}
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
_ub_comm->sms = ori_sms;
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::atomic_gemm_overlap_rs
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B,
bool transb, TensorWrapper &D, TensorWrapper &bias,
TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
bool grad, bool accumulate, bool use_split_accumulator,
TensorWrapper &rs_output, cudaStream_t stream_main) {
// Get GEMM dimensions
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n = _ubuf.size(0);
size_t m_chunk = m / _num_splits;
size_t input_a_chunk_size = m_chunk * k;
size_t output_chunk_size = n * m_chunk;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));
assert(pre_gelu_out.numel() == 0);
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_rs_overlap_first_gemm) {
auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[0]);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(
cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk,
n, m, _ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm);
}
} else {
for (int i = 0; i < _num_splits; i++) {
auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits - 1) {
_ub_comm->sms = UB_MAX_SM;
}
if (_ubuf.element_size() == 1) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m,
_ub_comm, _stream_comm););
} else {
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, _stream_comm);
}
rs_output_ptr += m_chunk * rs_output.element_size();
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
} // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
* Comm+GEMM Overlap P2P Base (Ring-Exchange)
**************************************************************************************************/
CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, int tp_size,
ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle,
CommOverlapType comm_type, int num_max_streams,
int comm_cga_size, int gemm_priority, int comm_priority,
int num_comm_sm, bool set_sm_margin, bool use_ce,
bool atomic_gemm, bool aggregate)
: CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
atomic_gemm) {
_is_p2p = true;
_is_reduce_scatter = comm_type == CommOverlapType::RS;
_aggregate = aggregate;
// Create workspace tensor with userbuffer
NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
int buffer_chunk_bytes = buffer_bytes / tp_size;
_num_ubuf_chunks = tp_size;
if (_is_reduce_scatter) {
// GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
// outputs for reduction at the end of the pipelining.
buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
_num_ubuf_chunks = tp_size * 2 - 1;
}
void *buffer_ptr;
_ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
_ubuf = TensorWrapper(buffer_ptr, {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
buffer_dtype);
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
for (int i = 0; i < _num_ubuf_chunks; i++) {
_ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
{buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
ubuf_byte_ptr += buffer_chunk_bytes;
}
_rank_round_tp = (_rank / _tp_size) * _tp_size;
_next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
_prev_rank = (_tp_size + _rank - 1) % _tp_size + _rank_round_tp;
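// Ring-exchange neighbors within the TP group, e.g. with _tp_size = 4 and _rank = 6:
// _rank_round_tp = 4, _tp_id = 2, _next_rank = 7, _prev_rank = 5.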
_self_chunk_id = _tp_id;
if (_atomic_gemm && !_is_reduce_scatter) {
_use_multiatomic_ag = getenv<bool>("NVTE_AG_P2P_MULTI_ATOMIC");
if (_use_multiatomic_ag) {
_use_ce = 0;
_ub_comm->push = 1;
if (_rank == 0) {
printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
}
}
_self_chunk_id = 0;
NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
}
for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
_stream_send.push_back(std::move(stream));
}
NVTE_CHECK_CUDA(
cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
CommOverlapP2PBase::~CommOverlapP2PBase() {
cudaEventDestroy(_stop_recv);
cudaEventDestroy(_stop_send);
cudaStreamDestroy(_stream_recv);
for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
size_t chunk_id) {
// Start with a chunk of the source tensor
auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));
// Update chunk with offset data pointers from the communication buffer
if (chunk.dptr() != nullptr) {
chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape());
}
if (chunk.columnwise_dptr() != nullptr) {
chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape());
}
return chunk;
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank are contiguous in memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::atomic_gemm_overlap_ag(
const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t n_chunk = _ubufs[0].size(0);
assert(pre_gelu_out.numel() == 0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Create a GEMM output buffer with N+1 chunks in contiguous memory
void *D_buffer_ptr;
int D_chunk_bytes = n_chunk * m * D.element_size();
NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(),
D.scale_inv(), D.scale_inv_shape(), D.scaling_mode());
// Reset atomic counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, true, stream_main);
// Catch up the default torch stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
for (int i = 0; i < _tp_size - 1; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubufs[rank]. This keeps
// the AG output contiguous on all ranks after the ring exchanges.
int send_chunk_id = i;
int recv_chunk_id = i + 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
if (_use_multiatomic_ag) {
if (i == 0) {
_ub_comm->use_ce = 0;
userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
_ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
true, _stream_recv);
}
} else {
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _next_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank,
_stream_recv);
producer(counter_ptr, recv_chunk_id, _stream_recv);
}
if (i == 0) {
nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false,
_counter.data(), stream_main);
}
}
// Store the input activation for backprop
if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
NVTE_CHECK_CUDA(
cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
_ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
// Copy the first GEMM output chunk to the end chunk position of D_buffer
char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr, D_chunk_bytes,
cudaMemcpyDeviceToDevice, stream_main));
// Return the last N rows of D_buffer
NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes, D.numel() * D.element_size(),
cudaMemcpyDeviceToDevice, stream_main));
// Clean up buffer allocation
NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));
_ub_comm->sms = ori_sms;
} // CommOverlapP2PBase::atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes input_b is pre-copied to _ubufs[rank_id]. This is needed so that the AG
** outputs on each rank are contiguous in memory after all ring-exchange phases.
*/
void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &B_copy,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get GEMM dimensions between TN and NN input layouts
const size_t m = (transa) ? A.size(0) : A.size(1);
const size_t k = (transa) ? A.size(1) : A.size(0);
const size_t n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const bool do_gelu = pre_gelu_out.numel() > 0;
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
if (_aggregate) {
const int num_steps = _tp_size / 2;
input_chunk_size *= 2;
output_chunk_size *= 2;
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
_stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
_stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;
// Ring exchange of 2X inputs chunks
for (int i = 0; i < num_steps; i++) {
send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
send_offset = comm_bytes * send_chunk_id;
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk =
get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
}
} else {
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. The buffer being sent is the input for the current
// GEMM chunk. The initial input chunk is stored in _ubufs[rank]. This keeps
// the AG output contiguous on all ranks after the ring exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
auto aux_chunk =
(do_gelu)
? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k})
: TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
auto workspace_chunk = get_tensor_chunk(
workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms,
_stream_compute[i % _stream_compute.size()]);
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, _stream_send[0]);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, _stream_recv);
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
NVTE_CHECK_CUDA(
cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, _stream_send[0]));
}
}
}
_ub_comm->sms = ori_sms;
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
} // CommOverlapP2PBase::split_overlap_ag
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::atomic_gemm_overlap_rs(
const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad,
bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Reset counters
int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
reset_counters(counter_ptr, _tp_size, false, stream_main);
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
// Atomic GEMM
// Process GEMM chunks in the order that AG+GEMM places the output chunks.
auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
transa, transb, grad, workspace.data(), accumulate, use_split_accumulator,
_math_sms, 0, _tp_size, true, _counter.data(), stream_main);
// P2P communication chunk
for (int i = 1; i < _tp_size; i++) {
int send_chunk_id = i - 1;
int recv_chunk_id = send_chunk_id + _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
consumer(counter_ptr, send_chunk_id, _stream_recv);
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_recv);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
const TensorWrapper &B, bool transb, TensorWrapper &D,
TensorWrapper &bias, TensorWrapper &pre_gelu_out,
TensorWrapper &workspace, bool grad, bool accumulate,
bool use_split_accumulator, TensorWrapper &rs_output,
cudaStream_t stream_main) {
int ori_sms = _ub_comm->sms;
_ub_comm->use_ce = _use_ce;
_ub_comm->sms = _num_comm_sm;
_ub_comm->cga_size = _cga_size;
// Get communication and GEMM input chunk sizes
size_t m = transa ? A.size(0) : A.size(1);
size_t k = transa ? A.size(1) : A.size(0);
size_t n_chunk = _ubufs[0].size(0);
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
// Get input and workspace data pointers
size_t input_chunk_size = n_chunk * k;
size_t output_chunk_size = n_chunk * m;
size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
// Catch up the main stream
NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (size_t i = 0; i < _stream_send.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0));
}
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
}
// GEMM and send/recv chunks
for (int i = 0; i < _tp_size; i++) {
// GEMM chunk
int stream_id = i % _stream_compute.size();
int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
auto output_chunk = get_buffer_chunk_by_id(D, i);
auto workspace_chunk =
get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
use_split_accumulator, _math_sms, _stream_compute[stream_id]);
if (i > 0) {
// P2P communication chunk
int prev_stream_id = (i - 1) % _stream_compute.size();
int send_offset = comm_bytes * (i - 1);
int recv_offset = comm_bytes * (i - 1 + _tp_size);
int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
_stream_send[prev_stream_id]);
userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
_stream_recv);
}
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
}
for (size_t i = 0; i < _stream_compute.size(); i++) {
NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i]));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
}
NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
// Reduce GEMM output chunks
char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
D.dtype(), fp8_type,
reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
_ubufs[0].numel(), stream_main););
} else {
reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
}
_ub_comm->sms = ori_sms;
}
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "ipcsocket.h"
#include <errno.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#define IPC_MAX_MSGLEN 4096
void ipc_warn(const char *format, ...) {
char buffer[IPC_MAX_MSGLEN];
va_list args;
va_start(args, format);
vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args);
snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n",
strerror(errno), errno);
fflush(stdout);
fputs(buffer, stderr);
fflush(NULL);
va_end(args);
}
static const char *ipcSocketResultStrings[static_cast<int>(ipcSocketNumResults)] = {
"Success", "Unhandled CUDA error", "System error", "Internal error",
"Invalid argument", "Invalid usage", "Remote error", "In progress",
};
const char *ipcSocketGetErrorString(ipcSocketResult_t res) {
return ipcSocketResultStrings[static_cast<int>(res)];
}
#define USE_ABSTRACT_SOCKET // Enable Linux abstract socket naming
#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx"
/*
* Create a Unix Domain Socket
*/
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag) {
int fd = -1;
struct sockaddr_un cliaddr;
char temp[IPC_SOCKNAME_LEN] = "";
if (handle == NULL) {
return ipcSocketInternalError;
}
handle->fd = -1;
handle->socketName[0] = '\0';
if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
ipc_warn("UDS: Socket creation error");
return ipcSocketSystemError;
}
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
// Create unique name for the socket.
size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
errno = ENAMETOOLONG;
ipc_warn("UDS: Cannot bind provided name to socket. Name too large");
return ipcSocketInternalError;
}
strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#else
unlink(temp);
#endif
if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
ipc_warn("UDS: Binding to socket %s failed", temp);
close(fd);
return ipcSocketSystemError;
}
handle->fd = fd;
strcpy(handle->socketName, temp); // NOLINT(*)
handle->abortFlag = abortFlag;
// Mark socket as non-blocking
if (handle->abortFlag) {
int flags = fcntl(fd, F_GETFL);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) {
if (handle == NULL) {
errno = EINVAL;
ipc_warn("ipcSocketSocketGetFd: pass NULL socket");
return ipcSocketInvalidArgument;
}
if (fd) *fd = handle->fd;
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) {
if (handle == NULL) {
return ipcSocketInternalError;
}
if (handle->fd <= 0) {
return ipcSocketSuccess;
}
#ifndef USE_ABSTRACT_SOCKET
if (handle->socketName[0] != '\0') {
unlink(handle->socketName);
}
#endif
close(handle->fd);
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
// Union to guarantee alignment requirements for control array
union {
struct cmsghdr cm;
char control[CMSG_SPACE(sizeof(int))];
} control_un;
struct cmsghdr *cmptr;
char dummy_buffer[1];
int ret;
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control);
if (hdr == NULL) {
iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer);
iov[0].iov_len = sizeof(dummy_buffer);
} else {
iov[0].iov_base = hdr;
iov[0].iov_len = hdrLen;
}
msg.msg_iov = iov;
msg.msg_iovlen = 1;
while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
ipc_warn("UDS: Receiving data over socket failed");
return ipcSocketSystemError;
}
if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
if (recvFd != NULL) {
if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
errno = EBADMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
} else {
errno = ENOMSG;
ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
return ipcSocketSystemError;
}
} else {
errno = EINVAL;
ipc_warn("UDS: File descriptor pointer cannot be NULL");
return ipcSocketInvalidArgument;
}
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) {
return ipcSocketRecvMsg(handle, NULL, 0, recvFd);
}
ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, const int sendFd,
int rank, uint64_t hash) {
struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
struct iovec iov[1];
char temp[IPC_SOCKNAME_LEN];
union {
struct cmsghdr cm;
char control[CMSG_SPACE(sizeof(int))];
} control_un;
struct cmsghdr *cmptr;
char dummy_buffer[1];
struct sockaddr_un cliaddr;
// Construct client address to send this shareable handle to
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
errno = ENAMETOOLONG;
ipc_warn("UDS: Cannot connect to provided name for socket. Name too large");
return ipcSocketInternalError;
}
(void)strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
cliaddr.sun_path[0] = '\0'; // Linux abstract socket trick
#endif
if (sendFd != -1) {
msg.msg_control = control_un.control;
msg.msg_controllen = sizeof(control_un.control);
cmptr = CMSG_FIRSTHDR(&msg);
cmptr->cmsg_len = CMSG_LEN(sizeof(int));
cmptr->cmsg_level = SOL_SOCKET;
cmptr->cmsg_type = SCM_RIGHTS;
memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));
}
msg.msg_name = reinterpret_cast<void *>(&cliaddr);
msg.msg_namelen = sizeof(struct sockaddr_un);
if (hdr == NULL) {
iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer);
iov[0].iov_len = sizeof(dummy_buffer);
} else {
iov[0].iov_base = hdr;
iov[0].iov_len = hdrLen;
}
msg.msg_iov = iov;
msg.msg_iovlen = 1;
msg.msg_flags = 0;
ssize_t sendResult;
while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
ipc_warn("UDS: Sending data over socket %s failed", temp);
return ipcSocketSystemError;
}
if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
}
return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank,
uint64_t hash) {
return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
}
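/*
** Illustrative usage sketch (not part of this file): passing a POSIX file descriptor between two
** ranks with the helpers above. The rank numbers, hash value, and fd are hypothetical; in
** userbuffers this pattern is used to exchange shareable memory-handle fds. IPCCHECK is the
** error-checking macro defined in userbuffers.cc.
**
**   IpcSocketHandle sock;
**   volatile uint32_t abort_flag = 0;
**   IPCCHECK(ipcSocketInit(&sock, my_rank, hash, &abort_flag));
**   if (my_rank == 0) {
**     IPCCHECK(ipcSocketSendFd(&sock, fd_to_share, 1, hash));   // send to rank 1's socket
**   } else if (my_rank == 1) {
**     int received_fd = -1;
**     IPCCHECK(ipcSocketRecvFd(&sock, &received_fd));
**   }
**   IPCCHECK(ipcSocketClose(&sock));
*/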
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <memory.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>
typedef enum {
ipcSocketSuccess = 0,
ipcSocketUnhandledCudaError = 1,
ipcSocketSystemError = 2,
ipcSocketInternalError = 3,
ipcSocketInvalidArgument = 4,
ipcSocketInvalidUsage = 5,
ipcSocketRemoteError = 6,
ipcSocketInProgress = 7,
ipcSocketNumResults = 8
} ipcSocketResult_t;
const char *ipcSocketGetErrorString(ipcSocketResult_t res);
#define IPC_SOCKNAME_LEN 64
struct IpcSocketHandle {
int fd;
char socketName[IPC_SOCKNAME_LEN];
volatile uint32_t *abortFlag;
};
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
volatile uint32_t *abortFlag);
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle);
ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash);
#endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <inttypes.h>
#include <math.h>
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <chrono>
#include <iostream>
#include <map>
#include <utility>
#include "common/util/cuda_driver.h"
#include "common/util/cuda_nvml.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "ipcsocket.h"
#include "userbuffers.h"
#ifdef NVTE_UB_WITH_MPI
static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD;
static MPI_Comm EXT_COMM_INTRA;
#define UB_MPI_CHECK(expr) \
do { \
const int mpicode = (expr); \
if (mpicode != MPI_SUCCESS) { \
char mpimsg[MPI_MAX_ERROR_STRING]; \
int mpilen; \
MPI_Error_string(mpicode, mpimsg, &mpilen); \
std::vector<char> errmsg(1024); \
snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", __FILE__, __LINE__, \
__func__, mpimsg); \
throw std::runtime_error(errmsg.data()); \
} \
} while (false)
void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
ExtComm comm) {
int numranks;
UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
assert(globalbytes == numranks * localbytes);
UB_MPI_CHECK(
MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm));
}
void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
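// Illustrative contract for the bootstrap callbacks (assumed, matching the MPI wrappers above):
// the allgather writes numranks * localbytes bytes into globaldata, with rank i's localdata placed
// at byte offset i * localbytes; the barrier must not return until every rank has entered it.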
#else
#define EXT_COMM_WORLD "world"
#define EXT_COMM_INTRA "intra"
#endif
#define MULTICAST_GB_TOTAL 512
#if CUDART_VERSION < 12030
// MNNVL: FABRIC handle support lifted from CUDA 12.3
#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL)
#define CU_IPC_HANDLE_SIZE 64
typedef struct CUmemFabricHandle_st {
unsigned char data[CU_IPC_HANDLE_SIZE];
} CUmemFabricHandle_v1;
typedef CUmemFabricHandle_v1 CUmemFabricHandle;
#endif
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define IPCCHECK(cmd) \
do { \
ipcSocketResult_t r = cmd; \
if (r != ipcSocketSuccess) { \
printf("Failed, UDS error %s:%d '%s'\n", __FILE__, __LINE__, ipcSocketGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define IPCCHECKGOTO(call, RES, label) \
do { \
RES = call; \
if (RES != ipcSocketSuccess && RES != ipcSocketInProgress) { \
goto label; \
} \
} while (0)
bool has_mnnvl_fabric(int device_id) {
#if CUDA_VERSION < 12040
if (getenv("NVTE_UBDEBUG")) {
printf(
"TransformerEngine does not support multi-node NVLINK "
"since it was not built with CUDA version >= 12.4.\n");
}
return false;
#else
bool mnnvl_fabric_support = false;
CUdevice dev;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
int fabric_handle_supported = 0;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
if (fabric_handle_supported) {
NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2);
nvmlDevice_t local_device;
NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device);
nvmlGpuFabricInfoV_t fabricInfo = {};
fabricInfo.version = nvmlGpuFabricInfo_v2;
fabricInfo.clusterUuid[0] = '\0';
NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo);
NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown);
if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') {
mnnvl_fabric_support = true;
}
}
if (getenv("NVTE_UBDEBUG")) {
if (mnnvl_fabric_support) {
printf("MNNVL NVLINK is supported on this platform.\n");
} else {
printf("MNNVL NVLINK is not supported on this platform.\n");
}
}
return mnnvl_fabric_support;
#endif
}
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
*comm = new communicator();
(*comm)->comm_world = EXT_COMM_WORLD;
(*comm)->_allgather = ext_allgather;
(*comm)->_barrier = ext_barrier;
(*comm)->nranks = numranks;
(*comm)->myrank = myrank;
(*comm)->free_region = 0;
(*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU;
int cur_dev, ndev;
cudaDeviceProp device_prop;
NVTE_CHECK_CUDA(cudaGetDevice(&cur_dev));
NVTE_CHECK_CUDA(cudaGetDeviceCount(&ndev));
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&device_prop, cur_dev));
(*comm)->sm_arch = device_prop.major;
// (*comm)->use_rr_kernel = device_prop.major == 8;
(*comm)->use_rr_kernel = 0;
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
int device_clock = 0;
// 110 sec wait time by default
int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110;
NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev));
(*comm)->ub_timeout = 1000ull * device_clock * sec_timeout;
if ((*comm)->myrank == 0) {
printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout,
(*comm)->ub_timeout, device_clock);
}
(*comm)->comm_intra = EXT_COMM_INTRA;
(*comm)->nvrank = mylocal;
(*comm)->nvsize = numlocal;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
  // Hard-coded CPU core pinning for up to 8 local ranks (DGX-style topology);
  // default to core 0 for any other local rank so `core` is never used uninitialized.
  int core = 0;
  if (mylocal == 0) core = 50;
  if (mylocal == 1) core = 58;
  if (mylocal == 2) core = 18;
  if (mylocal == 3) core = 26;
  if (mylocal == 4) core = 114;
  if (mylocal == 5) core = 122;
  if (mylocal == 6) core = 82;
  if (mylocal == 7) core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
if (core > 128)
CPU_SET(core - 128, &cpuset);
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
printf("%d: device used %d[%d] ,resetting device to %d\n", myrank, cur_dev, ndev, mylocal);
NVTE_CHECK_CUDA(cudaSetDevice(mylocal));
}
(*comm)->mydev = cur_dev;
  // FIXME: need to check that numlocal is a multiple of pipegpus * tensorgpus
  // ar1 is the data-parallel group
int divgpus = pipegpus * tensorgpus;
int datagpus = numlocal / divgpus;
(*comm)->ar_nvsize = datagpus;
(*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus;
(*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus;
// ar2 is tensor
(*comm)->ar2_nvsize = tensorgpus;
(*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus;
(*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu;
  // the ar (data) group strides by tensorgpus; the ar2 (tensor) group is contiguous
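  // Worked example (illustrative values): numlocal = 8, pipegpus = 1, tensorgpus = 4
  //   => datagpus = 2. For mylocal = 6:
  //   ar2 (tensor) group = {4,5,6,7}: ar2_firstgpu = 4, ar2_nvrank = 2, ar2_nvsize = 4
  //   ar  (data)   group = {2,6} (stride tensorgpus): ar_firstgpu = 2, ar_nvrank = 1, ar_nvsize = 2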
int allnodes = numranks / numlocal;
int nodeid = myrank / numlocal;
(*comm)->num_nodes = numnodes;
(*comm)->my_node = mynode;
#define NBUF 2
#if CUDART_VERSION >= 12010
bool mnnvl_fabric = has_mnnvl_fabric(cur_dev);
if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
// multicast init only for TP ops (____2 operations)
size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30);
(*comm)->mc_offset = 0;
(*comm)->use_mc = 1;
size_t gran;
CUmulticastObjectProp mcProp = {};
mcProp.numDevices = (*comm)->ar2_nvsize;
    mcProp.size = mc_maxsize;  // use the local request size; (*comm)->mc_maxsize is not set yet
mcProp.handleTypes =
mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMulticastGetGranularity, &gran, &mcProp,
static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_RECOMMENDED));
mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran;
mcProp.size = mc_maxsize;
(*comm)->mc_maxsize = mc_maxsize;
if ((*comm)->ar2_nvrank == 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp);
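    // Only ar2 rank 0 creates the multicast object; its handle is then shared with the rest of
    // the TP group, either as a fabric handle (MNNVL path) or as a POSIX file descriptor passed
    // over Unix domain sockets (non-MNNVL path below).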
if (mnnvl_fabric) {
CUmemFabricHandle *exphndl =
reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
CUmemFabricHandle *tmphndl =
reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
CUmemFabricHandle *exphndls;
NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle)));
if ((*comm)->ar2_nvrank == 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast<void *>(tmphndl),
(*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
for (int grp = 0; grp < (*comm)->ar_nvsize;
grp++) { // we do N broadcasts for N TP groups in NVL domain
int root = grp * (*comm)->ar2_nvsize;
// It just needs to be a bcast but reuse existing allgather comm
(*comm)->_allgather(
reinterpret_cast<void *>(exphndls), (*comm)->nvsize * sizeof(CUmemFabricHandle),
reinterpret_cast<void *>(tmphndl), sizeof(CUmemFabricHandle), (*comm)->comm_intra);
      // Save the handle if the broadcast originated from rank 0 of our TP group
if ((*comm)->ar2_firstgpu == root)
memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle));
}
if ((*comm)->ar2_nvrank != 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle,
reinterpret_cast<void *>(exphndl), CU_MEM_HANDLE_TYPE_FABRIC);
free(exphndl);
free(tmphndl);
NVTE_CHECK_CUDA(cudaFreeHost(exphndls));
} else {
      // Broadcast a POSIX file descriptor from the local root rank to the other local ranks.
// NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
// file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use Unix domain sockets for the kernel to
// recreate the correct file descriptor on every receiving rank.
int fd;
volatile uint32_t abortFlag = 0;
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafe0000 + (*comm)->my_node;
ipcSocketResult_t ret = ipcSocketSuccess;
IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
(*comm)->_barrier((*comm)->comm_world);
if ((*comm)->ar2_nvrank == 0) {
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemExportToShareableHandle, reinterpret_cast<void *>(&fd), (*comm)->mc_handle,
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
(uint64_t)0);
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
}
} else {
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error);
}
}
error:
if ((*comm)->ar2_nvrank != 0) {
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
}
IPCCHECK(ipcSocketClose(&ipcSock));
close(fd);
}
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
(CUdeviceptr)(*comm)->mydev);
CUdeviceptr mc_va;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &mc_va, mc_maxsize, (size_t)0, (CUdeviceptr)0U,
(uint64_t)0);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, mc_va, mc_maxsize, (size_t)0, (*comm)->mc_handle,
(uint64_t)0);
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.location.id = (*comm)->mydev;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, mc_va, mc_maxsize,
const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);
(*comm)->mc_baseptr = reinterpret_cast<void *>(mc_va);
(*comm)->_barrier((*comm)->comm_world);
if (!(*comm)->myrank) printf("MC initialized succesfully, window size = %ld\n", mc_maxsize);
} else {
#endif
if (!(*comm)->myrank) printf("MC NOT initialized and used\n");
(*comm)->mc_maxsize = 0;
(*comm)->mc_offset = 0;
(*comm)->use_mc = 0;
#if CUDART_VERSION >= 12010
}
#endif
#define LOCALSIZE (4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF))
// peer pointers + op flags + comm buffer
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
NVTE_CHECK_CUDA(
cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
(*comm)->sms = 16;
(*comm)->threads = 1024;
#define GPU_PAGE_SHIFT 16
#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT)
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
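  // Align the flags pointer up to a GPU page boundary. The allocation spans two GPU pages and
  // was zeroed above, so at least one full page of zeroed flag space remains after alignment
  // (the original, unaligned pointer is intentionally not kept).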
using namespace std;
sched_param param;
pthread_attr_t attr;
pthread_attr_init(&attr);
pthread_attr_getschedparam(&attr, &param);
param.sched_priority = sched_get_priority_max(SCHED_FIFO);
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf(
"%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP "
"%dx%d\n",
myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize,
(*comm)->num_nodes, (*comm)->ar2_nvsize);
fflush(NULL);
return 0;
}
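// Illustrative bootstrap without MPI (sketch only; `my_allgather_fn` / `my_barrier_fn` are
// hypothetical caller-provided callbacks matching ExtAllgatherOp / ExtBarrierOp, and `tp_size`
// is the caller's tensor-parallel group size):
//   communicator *comm = nullptr;
//   create_communicator_grouped2(&comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
//                                my_allgather_fn, my_barrier_fn,
//                                /*pipegpus=*/1, /*pipenodes=*/1,
//                                /*tensorgpus=*/tp_size, /*tensornodes=*/1);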
int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
int numlocal, int mynode, int numnodes,
ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
}
int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
int mynode, int numnodes, ExtAllgatherOp ext_allgather,
ExtBarrierOp ext_barrier) {
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_allgather, ext_barrier, 1, 1, 1, 1);
}
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
int tensorgpus, int tensornodes) {
#ifdef NVTE_UB_WITH_MPI
// get global numbers
int myrank, numranks;
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks));
int mylocal, numlocal;
UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA));
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal));
  // derive inter-node numbers (no separate inter-node communicator is created here)
  NVTE_CHECK_CUDA(cudaFree(0));  // ensure a CUDA context exists on this rank
int mynode, numnodes;
mynode = myrank / numlocal;
numnodes = numranks / numlocal;
// finally call the abstracted constructor with MPI info
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
&ub_mpi_allgather, &ub_mpi_barrier, pipegpus, pipenodes,
tensorgpus, tensornodes);
#else
NVTE_ERROR(std::string("Bootstrapping Userbuffers with MPI requires building") +
std::string("Transformer Engine with NVTE_UB_WITH_MPI=1 and MPI_HOME=/path/to/mpi"));
#endif
}
int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes) {
return create_communicator_grouped2_mpi(comm, pipegpus, pipenodes, 1, 1);
}
int create_communicator_mpi(communicator **comm) {
return create_communicator_grouped2_mpi(comm, 1, 1, 1, 1);
}
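// Illustrative MPI bootstrap (sketch only; requires a build with NVTE_UB_WITH_MPI=1):
//   MPI_Init(&argc, &argv);
//   communicator *comm = nullptr;
//   create_communicator_mpi(&comm);
//   ... register buffers and run collectives ...
//   destroy_communicator_mpi(comm);
//   MPI_Finalize();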
void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) {
if (comm->use_mc && comm->mem_dealloc[hndl]) {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank == comm->nvrank) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
} else {
comm->uchandles[hndl][rank] = 0;
}
}
free(reinterpret_cast<void *>(comm->uchandles[hndl]));
} else {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank != comm->nvrank) {
cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
} else if (comm->mem_dealloc[hndl]) {
NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank]));
} else {
comm->peer_ptr[hndl][rank] = nullptr; // remove reference to external buffer
}
}
}
free(comm->peer_ptr[hndl]);
comm->mem_ptr[hndl] = nullptr;
}
cudaFree(reinterpret_cast<void *>(comm->recv_id));
cudaFree(reinterpret_cast<void *>(comm->send_id));
if (comm->use_mc) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
}
delete comm;
}
void destroy_communicator_mpi(communicator *comm) {
#ifdef NVTE_UB_WITH_MPI
MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_intra)));
destroy_communicator(comm);
#else
NVTE_ERROR(std::string("Communicator is not bootstrapped with MPI and ") +
std::string("can only be deallocated with destroy_communicator()."));
#endif
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;  // all region slots are in use
int hndl = comm->free_region;
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
size_t aligned_size = bytes;
comm->memflags[hndl] = 0;
comm->mem_dealloc[hndl] = alloc;
#if CUDART_VERSION >= 12010
if (comm->use_mc && alloc) {
bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev);
int nranks = comm->nvsize; // total GPUs in NVLINK domain
int myrank = comm->nvrank;
void **remptrs = reinterpret_cast<void **>(malloc(nranks * sizeof(void *)));
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = comm->mydev;
prop.requestedHandleTypes =
mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
size_t granularity = 0;
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemGetAllocationGranularity, &granularity, &prop,
static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
    // TODO: all ranks should agree on the maximum granularity (e.g. via an allreduce MAX)
aligned_size = (bytes + granularity - 1) / granularity * granularity;
if (comm->use_mc) {
CUmulticastObjectProp mcProp = {};
mcProp.numDevices = nranks;
mcProp.size = aligned_size;
mcProp.handleTypes = prop.requestedHandleTypes;
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMulticastGetGranularity, &granularity, &mcProp,
static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
aligned_size = (aligned_size + granularity - 1) / granularity * granularity;
}
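    // Example (illustrative): with a 2 MiB allocation granularity, a 3 MiB request is first
    // rounded up to 4 MiB; the multicast-granularity pass above may round it up further.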
prop.location.id = comm->mydev;
comm->uchandles[hndl] = reinterpret_cast<CUmemGenericAllocationHandle *>(
malloc(nranks * sizeof(CUmemGenericAllocationHandle)));
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop,
(uint64_t)0);
if (mnnvl_fabric) {
CUmemFabricHandle *exphndl;
CUmemFabricHandle myhndl;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl,
comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0);
NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle)));
comm->_allgather(reinterpret_cast<void *>(exphndl), comm->nvsize * sizeof(CUmemFabricHandle),
reinterpret_cast<void *>(&myhndl), sizeof(CUmemFabricHandle),
comm->comm_intra);
for (int p = 0; p < nranks; p++)
if (p != myrank)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
reinterpret_cast<void *>(&exphndl[p]),
CU_MEM_HANDLE_TYPE_FABRIC);
NVTE_CHECK_CUDA(cudaFreeHost(exphndl));
} else {
int *peerfd = reinterpret_cast<int *>(malloc(nranks * sizeof(int)));
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemExportToShareableHandle, reinterpret_cast<void *>(&peerfd[myrank]),
comm->uchandles[hndl][myrank],
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
(uint64_t)0);
volatile uint32_t abortFlag = 0;
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafe0000 + comm->my_node;
ipcSocketResult_t ret = ipcSocketSuccess;
      // Exchange POSIX file descriptors across local ranks over Unix domain sockets
IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
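      // Ring exchange: at step p, every rank sends its own fd to (myrank + p) % nranks and
      // receives the fd of (myrank - p + nranks) % nranks, so after nranks - 1 steps each rank
      // holds a kernel-recreated, valid fd for every peer's allocation.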
for (int p = 1; p < nranks; p++) {
int send_to = (myrank + p) % nranks;
int recv_from = (myrank + nranks - p) % nranks;
comm->_barrier(comm->comm_intra);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret,
error);
IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error);
}
error:
IPCCHECK(ipcSocketClose(&ipcSock));
for (int p = 0; p < nranks; p++) {
if (p != myrank)
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
reinterpret_cast<void *>(peerfd[p]),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(peerfd[p]);
}
free(peerfd);
}
CUdeviceptr ptr;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks),
(size_t)0, (CUdeviceptr)0, (uint64_t)0);
comm->ucbase_ptr[hndl] = reinterpret_cast<void *>(ptr);
CUmemAccessDesc accessDesc = {};
accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
accessDesc.location.id = comm->mydev;
for (int i = 0; i < nranks; i++) {
remptrs[i] = reinterpret_cast<void *>(ptr + (aligned_size * i));
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, reinterpret_cast<CUdeviceptr>(remptrs[i]), aligned_size,
(size_t)0, comm->uchandles[hndl][i], (uint64_t)0);
if (i == comm->nvrank) {
if (hndl)
*gpubuff = remptrs[i];
else
comm->gpu_ptrs = remptrs[i];
}
comm->peer_ptr[hndl][i] = remptrs[i];
}
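      // All peer allocations are mapped back-to-back into one reserved VA range: peer i's buffer
      // lives at ptr + i * aligned_size, and this rank's own slice doubles as the local buffer
      // (gpu_ptrs for handle 0, *gpubuff otherwise).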
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, ptr, (size_t)(aligned_size * nranks),
const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);
if (hndl == 0) NVTE_CHECK_CUDA(cudaMemset(comm->gpu_ptrs, 0, aligned_size));
NVTE_CHECK_CUDA(
cudaMemcpy((reinterpret_cast<char *>(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)),
remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice));
free(remptrs);
comm->memflags[hndl] = NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED;
if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset,
comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/,
aligned_size, (uint64_t)0);
comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED;
comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
comm->mc_offset += aligned_size;
} else if (!comm->myrank) {
printf("UB: warning region %d size %ld MB registered without MC access\n", hndl,
aligned_size / 1024 / 1024);
}
} else {
#endif
if (alloc) {
NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
}
NVTE_CHECK(comm->nvsize <= 8, "CUDA IPC supports only up to 8 GPUs in an NVLink domain.");
cudaIpcMemHandle_t memhndl;
NVTE_CHECK_CUDA(cudaIpcGetMemHandle(&memhndl, *gpubuff));
cudaIpcMemHandle_t *tmp =
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(comm->nvsize * sizeof(cudaIpcMemHandle_t)));
comm->_allgather(reinterpret_cast<void *>(tmp), comm->nvsize * sizeof(cudaIpcMemHandle_t),
reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
cudaIpcMemLazyEnablePeerAccess));
}
}
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(
reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
free(tmp);
#if CUDART_VERSION >= 12010
}
#endif
comm->mem_size[hndl] = aligned_size;
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
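// Illustrative registration (sketch only; `comm` is an already-bootstrapped communicator):
//   void *ubuf = nullptr;
//   int hndl = register_user_buffer_collective(&ubuf, bytes, comm, /*alloc=*/true);
//   // On success, hndl indexes comm->mem_ptr / comm->peer_ptr / comm->mem_size;
//   // -1 means no region slots were left.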