OpenDAS / TransformerEngine — Commit 5b6ef054
Authored Mar 17, 2025 by yuguo

Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine
Parents: 76060570, a7eeb28b

Showing 5 changed files with 2053 additions and 0 deletions (+2053, -0). The commit touches 225+ files in total; only 225 are displayed to preserve performance.
transformer_engine/common/activation/swiglu.cu (+34, -0)
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (+1041, -0)
transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc (+262, -0)
transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h (+52, -0)
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp (+664, -0)
transformer_engine/common/activation/swiglu.cu (new file, mode 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../util/math.h"
#include "./activation_template.h"
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_silu);
  using namespace transformer_engine;
  act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}

void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                cudaStream_t stream) {
  NVTE_API_CALL(nvte_dsilu);
  using namespace transformer_engine;
  dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swiglu);
  using namespace transformer_engine;
  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}

void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_dswiglu);
  using namespace transformer_engine;
  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
}
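For reference, the computation that nvte_swiglu dispatches through gated_act_fn can be written as a short device kernel. This is a minimal semantic sketch, assuming the common gated layout where each input row holds the activated half followed by the linear half ([rows, 2*cols] in, [rows, cols] out); it is not the templated, vectorized kernel the library actually launches.

// Reference semantics only: out[r][c] = silu(in[r][c]) * in[r][c + cols],
// with silu(x) = x * sigmoid(x). Assumes fp32 and a [rows, 2*cols] input layout.
__global__ void swiglu_reference(const float *in, float *out, int rows, int cols) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= rows * cols) return;
  int r = idx / cols, c = idx % cols;
  float gate = in[r * 2 * cols + c];        // activated half
  float lin = in[r * 2 * cols + cols + c];  // linear half
  float s = gate / (1.0f + expf(-gate));    // SiLU(gate)
  out[idx] = s * lin;
}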
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (new file, mode 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
#include <numeric>
#include "common/common.h"
#include "common/util/cuda_driver.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define AS_VECTOR(shape) std::vector<size_t>(shape.data, shape.data + shape.ndim)
using namespace std::placeholders;

namespace transformer_engine {

/***************************************************************************************************
 * Comm+GEMM Overlap Common Core
 **************************************************************************************************/

bool ubuf_built_with_mpi() {
#ifdef NVTE_UB_WITH_MPI
  return true;
#else
  return false;
#endif
}
CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode,
                                 int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
                                 ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
                                 int comm_cga_size, int gemm_priority, int comm_priority,
                                 int num_comm_sm, bool set_sm_margin, bool use_ce,
                                 bool atomic_gemm) {
  // Initialize userbuf communicator
  if (!_comm_created) {
    if (myrank == 0) {
      printf("!!! [UB] Create Userbuffers Communicator\n");
    }
#ifdef NVTE_UB_WITH_MPI
    create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
#else
    create_communicator_grouped2(&_ub_comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                 allgather_handle, barrier_handle, 1, 1, tp_size, 1);
#endif
    _comm_created = true;
  }

  _use_ce = static_cast<int>(use_ce);
  _num_comm_sm = num_comm_sm;
  _cga_size = comm_cga_size;
  if (gemm_priority == 0 && comm_priority == 0) {
    transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority);
  } else {
    _gemm_priority = gemm_priority;
    _comm_priority = comm_priority;
  }
  for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
    cudaStream_t stream;
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
    _stream_compute.push_back(std::move(stream));
  }

  _num_splits = num_splits;
  _rank = _ub_comm->myrank;
  _tp_size = tp_size;
  _tp_id = _rank % _tp_size;

  // Set the number of SMs for GEMM with margin
  int sm_count = transformer_engine::cuda::sm_count();
  _math_sms = (set_sm_margin) ? sm_count - num_comm_sm : sm_count;
  _math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

  _atomic_gemm = atomic_gemm;
  if (_atomic_gemm) {
    void *counter_ptr;
    size_t counter_bytes = _num_splits * 2 * sizeof(int32_t);
    NVTE_CHECK_CUDA(cudaMalloc(&counter_ptr, counter_bytes));
    NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 0, counter_bytes));
    NVTE_CHECK_CUDA(cudaMemset(counter_ptr, 1, counter_bytes / 2));
    _counter = TensorWrapper(counter_ptr,
                             std::vector<size_t>{static_cast<size_t>(_num_splits * 2)},
                             DType::kInt32);
  }

  // CUDA event creation
  cudaEventCreateWithFlags(&_start_compute, 0);
  cudaEventCreateWithFlags(&_stop_compute, 0);
  cudaEventCreateWithFlags(&_start_comm, 0);
  cudaEventCreateWithFlags(&_stop_comm, 0);

  /*
    Defining the launch order between the communication and GEMM kernels
    using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS > 1.
    The event is used to schedule the communication kernel before the GEMM.
    This is needed only for Hopper, which uses persistent CTA execution.
  */
  int max_connection = transformer_engine::getenv<int>("CUDA_DEVICE_MAX_CONNECTIONS", 8);
  int runtime_version = 0;
  cudaRuntimeGetVersion(&runtime_version);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) {
    cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming);
  } else {
    _comm_launch_event = 0;
  }
}
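When no explicit priorities are given, the constructor derives them via transformer_engine::cuda::stream_priority_range. A minimal sketch of the underlying CUDA runtime pattern; the mapping of the range endpoints to GEMM vs. communication here is an assumption, not the wrapper's documented behavior:

#include <cuda_runtime.h>
// CUDA reports priorities as a [leastPriority, greatestPriority] range, where
// numerically lower values run at higher priority.
void priority_sketch() {
  int least_priority = 0, greatest_priority = 0;
  cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority);
  int gemm_priority = least_priority;     // assumption: GEMM at the low-priority end
  int comm_priority = greatest_priority;  // assumption: comm kernels get scheduled sooner
  cudaStream_t comm_stream;
  cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, comm_priority);
  cudaStreamDestroy(comm_stream);
  (void)gemm_priority;
}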
CommOverlapCore::~CommOverlapCore() {
  cudaEventDestroy(_stop_comm);
  cudaEventDestroy(_start_comm);
  cudaEventDestroy(_stop_compute);
  cudaEventDestroy(_start_compute);
  if (_comm_launch_event) cudaEventDestroy(_comm_launch_event);

  if (_atomic_gemm) cudaFree(_counter.dptr());

  for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]);

  if (_comm_created) {
#ifdef NVTE_UB_WITH_MPI
    destroy_communicator_mpi(_ub_comm);
#else
    destroy_communicator(_ub_comm);
#endif
    _comm_created = false;
  }
}
TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset,
                                                const std::vector<size_t> &chunk_shape) {
  TensorWrapper chunk;
  for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) {
    auto param_type = static_cast<NVTETensorParam>(param_id);
    auto param = source.get_parameter(param_type);
    auto param_dptr = reinterpret_cast<char *>(param.data_ptr);
    auto param_dtype = static_cast<DType>(param.dtype);
    auto param_shape = AS_VECTOR(param.shape);

    if (param_dptr != nullptr) {
      if (param_type == NVTETensorParam::kNVTERowwiseData ||
          param_type == NVTETensorParam::kNVTEColumnwiseData) {
        // Offset data pointer
        param_dptr += chunk_offset * typeToSize(param_dtype);
        param_shape = chunk_shape;
        if (param_type == NVTETensorParam::kNVTEColumnwiseData &&
            source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) {
          // Columnwise shape for non-block scaled tensors shifts the last dimension to the front
          auto last_dim = param_shape.back();
          param_shape.pop_back();
          param_shape.insert(param_shape.begin(), last_dim);
        }
      } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING &&
                 (param_type == NVTETensorParam::kNVTERowwiseScaleInv ||
                  param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) {
        // Calculate block scaling offset and size
        auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
                                          ? source.shape().data[0]
                                          : source.columnwise_shape().data[0];
        auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv)
                                         ? chunk_shape.front()
                                         : chunk_shape.back();
        auto chunk_scale_start = chunk_offset / 32;
        auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;
        auto chunk_scale_size = chunk_scale_end - chunk_scale_start;
        param_dptr += chunk_scale_start * typeToSize(param_dtype);
        param_shape = std::vector<size_t>{chunk_scale_size};
      }

      // Set chunked source parameters into the chunked tensor output
      chunk.set_parameter(param_type, reinterpret_cast<void *>(param_dptr), param_dtype,
                          param_shape);
    }
  }
  return chunk;
}
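The block-scaling branch above assumes one scale-inverse entry per 32 elements along the scaled dimension, so chunk offsets translate into scale offsets by integer division. A worked example under that assumption:

// Example: a chunk starting at element 4096 that covers 2048 rows of the
// scaled dimension, with 32 elements per MXFP8 scaling block:
size_t chunk_offset = 4096, scaled_chunk_dim_size = 2048;
size_t chunk_scale_start = chunk_offset / 32;                          // = 128
size_t chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32;  // = 192
size_t chunk_scale_size = chunk_scale_end - chunk_scale_start;         // = 64 scale entries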
TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source,
                                                     size_t chunk_offset,
                                                     const std::vector<size_t> &chunk_shape) {
  // Start with a chunk of the source tensor
  auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape);

  // Update chunk with offset data pointers from the communication buffer
  auto ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size());
  if (chunk.dptr() != nullptr) {
    chunk.set_rowwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(), chunk.shape());
  }
  if (chunk.columnwise_dptr() != nullptr) {
    chunk.set_columnwise_data(reinterpret_cast<void *>(ubuf_ptr), chunk.dtype(),
                              chunk.columnwise_shape());
  }
  return chunk;
}
/***************************************************************************************************
 * Comm+GEMM Overlap Base (Pipelined / Collective)
 **************************************************************************************************/

CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
                                 int myrank, int numranks, int mylocal, int numlocal, int mynode,
                                 int numnodes, int tp_size, ExtAllgatherOp allgather_handle,
                                 ExtBarrierOp barrier_handle, int num_splits, int num_max_streams,
                                 int comm_cga_size, int gemm_priority, int comm_priority,
                                 int num_comm_sm, bool set_sm_margin, bool atomic_gemm,
                                 bool rs_overlap_first_gemm)
    : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
                      allgather_handle, barrier_handle, num_splits, num_max_streams,
                      comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin,
                      false, atomic_gemm) {
  _rs_overlap_first_gemm = rs_overlap_first_gemm;
  _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
  NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
             "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ",
             "or 2 (multi-atomic).");

  NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
  size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
  void *buffer_ptr;
  _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
  if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
  _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);

  NVTE_CHECK_CUDA(
      cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
}
CommOverlapBase::~CommOverlapBase() {
  cudaEventDestroy(_start_d2dcopy);
  cudaStreamDestroy(_stream_comm);
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B,
                                   bool transb, TensorWrapper &D, TensorWrapper &bias,
                                   TensorWrapper &pre_gelu_out, TensorWrapper &workspace,
                                   bool grad, bool accumulate, bool use_split_accumulator,
                                   CommOverlapType comm_type, TensorWrapper &rs_output,
                                   cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0));

  // Communication: AG and RS
  int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size();  // UBUF uses 2Byte element size
  if (comm_type == CommOverlapType::AG) {
    allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                (cudaEvent_t)_comm_launch_event);
  } else {
    if (_ubuf.element_size() == 1) {
      assert(_ubuf_scale_inv_initialized);
      comm_elements *= 2;
      assert(rs_output.numel() == _ubuf.numel() / _tp_size);
      assert(rs_output.size(0) == _ubuf.size(0) / _tp_size);
      assert(rs_output.element_size() == 2);
      char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
      reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0,
                                                 comm_elements, _ub_comm, _stream_comm,
                                                 (cudaEvent_t)_comm_launch_event);
    } else {
      reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm,
                                      (cudaEvent_t)_comm_launch_event);
    }
  }

  assert(pre_gelu_out.numel() == 0);
  // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the
  // communication kernel launch
  if (_comm_launch_event)
    NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0));
  nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb,
                   grad, workspace.data(), accumulate, use_split_accumulator, _math_sms,
                   _stream_compute[0]);

  _ub_comm->sms = ori_sms;
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::bulk_overlap
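The synchronization in bulk_overlap is the standard fork/join CUDA event pattern: fork the communication and compute streams off the main stream, then join both back before returning control. A standalone sketch of just that pattern (names are illustrative placeholders, not TE API):

#include <cuda_runtime.h>
// Sketch: fork two side streams off `main_stream`, then join both back.
// `launch_comm_kernel` / `launch_gemm_kernel` are hypothetical stand-ins for
// the userbuffers collective and the cuBLAS GEMM launch.
void fork_join_sketch(cudaStream_t main_stream, cudaStream_t comm_stream,
                      cudaStream_t compute_stream,
                      void (*launch_comm_kernel)(cudaStream_t),
                      void (*launch_gemm_kernel)(cudaStream_t)) {
  cudaEvent_t fork_ev, join_ev;
  cudaEventCreateWithFlags(&fork_ev, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&join_ev, cudaEventDisableTiming);

  cudaEventRecord(fork_ev, main_stream);           // fork: side streams wait on main
  cudaStreamWaitEvent(comm_stream, fork_ev, 0);
  cudaStreamWaitEvent(compute_stream, fork_ev, 0);

  launch_comm_kernel(comm_stream);
  launch_gemm_kernel(compute_stream);

  cudaEventRecord(join_ev, comm_stream);           // join: main waits on both sides
  cudaStreamWaitEvent(main_stream, join_ev, 0);
  cudaEventRecord(join_ev, compute_stream);        // event reuse mirrors _stop_comm above
  cudaStreamWaitEvent(main_stream, join_ev, 0);

  cudaEventDestroy(fork_ev);
  cudaEventDestroy(join_ev);
}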
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa,
                                             const TensorWrapper &B, bool transb, TensorWrapper &D,
                                             TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                                             TensorWrapper &workspace, bool grad, bool accumulate,
                                             bool use_split_accumulator, TensorWrapper &rs_output,
                                             cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get GEMM dimensions
  size_t m = transa ? A.size(0) : A.size(1);
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n = _ubuf.size(0);
  size_t m_chunk = m / _num_splits;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  // Get input, output, and workspace data pointers
  char *input_a_chunk_ptr = reinterpret_cast<char *>(A.dptr());
  char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.dptr());
  char *workspace_ptr = reinterpret_cast<char *>(workspace.dptr());
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());

  // Reset atomic counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
  reset_counters(counter_ptr, _num_splits, false, stream_main);

  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));

  assert(pre_gelu_out.numel() == 0);

  auto output_d = get_buffer_chunk_like(D, 0, {n, m});
  auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
  nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
                          transa, transb, grad, workspace_chunk.data(), accumulate,
                          use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(),
                          _stream_compute[0]);

  for (int i = 0; i < _num_splits; i++) {
    if (_rs_kernel_type == 1) {
      if (i == _num_splits - 1) {
        _ub_comm->sms = UB_MAX_SM;
      }
      if (_ubuf.element_size() == 1) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
            D.dtype(), fp8_type,
            reducescatter2_userbuff_strided_atomic_fp8<fp8_type>(
                rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits,
                &counter_ptr[i], _ub_comm, _stream_comm););
      } else {
        reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
                                               _num_splits, &counter_ptr[i], _ub_comm,
                                               _stream_comm);
      }
    } else if (_rs_kernel_type == 2) {
      if (_ubuf.element_size() == 1) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
            D.dtype(), fp8_type,
            reducescatter2_userbuff_strided_multiatomic_fp8<fp8_type>(
                rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits,
                counter_ptr, _ub_comm, _stream_comm););
      } else {
        reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n,
                                                    m, _num_splits, counter_ptr, _ub_comm,
                                                    _stream_comm);
      }
      break;
    } else {
      consumer(counter_ptr, i, _stream_comm);
      if (_ubuf.element_size() == 1) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
            D.dtype(), fp8_type,
            reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(rs_output_ptr, D.scale_inv(),
                                                                _ub_reg, i * m_chunk, m_chunk, n,
                                                                m, _ub_comm, _stream_comm););
      } else {
        reducescatter2_userbuff_strided(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m,
                                        _ub_comm, _stream_comm);
      }
    }

    rs_output_ptr += m_chunk * rs_output.element_size();
  }

  _ub_comm->sms = ori_sms;
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[0]));
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::atomic_gemm_overlap_rs
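The atomic-GEMM path coordinates the two streams through per-chunk counters: the GEMM publishes a counter when a chunk's output is ready, and the communication stream spins until it observes that. The actual producer()/consumer() kernels live in userbuffers; a minimal sketch of the signaling idea only, with an assumed convention:

// Illustrative counter-based signaling between two streams; NOT TE's kernels.
// Convention assumed here: counters[id] == 0 means "chunk id is ready".
__global__ void producer_sketch(int *counters, int id) {
  if (threadIdx.x == 0) atomicExch(&counters[id], 0);  // publish readiness
}
__global__ void consumer_sketch(int *counters, int id) {
  if (threadIdx.x == 0) {
    volatile int *flag = counters + id;
    while (*flag != 0) {  // spin until the producer publishes
    }
  }
}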
/*
** Split FPROP GEMM + ReduceScatter
*/
void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                       const TensorWrapper &B, bool transb, TensorWrapper &D,
                                       TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                                       TensorWrapper &workspace, bool grad, bool accumulate,
                                       bool use_split_accumulator, TensorWrapper &rs_output,
                                       cudaStream_t stream_main) {
  // Get GEMM dimensions
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;
  size_t m = transa ? A.size(0) : A.size(1);
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n = _ubuf.size(0);
  size_t m_chunk = m / _num_splits;
  size_t input_a_chunk_size = m_chunk * k;
  size_t output_chunk_size = n * m_chunk;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
  }
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_compute, 0));

  assert(pre_gelu_out.numel() == 0);

  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
  if (_rs_overlap_first_gemm) {
    auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k});
    auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
    auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});

    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                     accumulate, use_split_accumulator, _math_sms, _stream_compute[0]);

    for (int i = 1; i < _num_splits; i++) {
      input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
      output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
      workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                       accumulate, use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);

      NVTE_CHECK_CUDA(
          cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));

      // Communication chunk
      if (_ubuf.element_size() == 1) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
            D.dtype(), fp8_type,
            reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
                rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m,
                _ub_comm, _stream_comm););
      } else {
        reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
                                              (i - 1) * output_chunk_size, m_chunk, n, m,
                                              _ub_comm, _stream_comm);
      }

      rs_output_ptr += m_chunk * rs_output.element_size();
    }
    int last_compute_stream_id =
        (_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
    NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[last_compute_stream_id]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));

    // Last communication chunk with max SM
    _ub_comm->sms = UB_MAX_SM;
    if (_ubuf.element_size() == 1) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          D.dtype(), fp8_type,
          reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
              rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size,
              m_chunk, n, m, _ub_comm, _stream_comm););
    } else {
      reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
                                            (_num_splits - 1) * output_chunk_size, m_chunk, n, m,
                                            _ub_comm, _stream_comm);
    }
  } else {
    for (int i = 0; i < _num_splits; i++) {
      auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
      auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk});
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                       accumulate, use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);

      NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));

      // Communication chunk. Uses MAX_SM at the last chunk
      if (i == _num_splits - 1) {
        _ub_comm->sms = UB_MAX_SM;
      }
      if (_ubuf.element_size() == 1) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
            D.dtype(), fp8_type,
            reducescatter2_userbuff_stridedoutput_fp8<fp8_type>(
                rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m,
                _ub_comm, _stream_comm););
      } else {
        reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
                                              m_chunk, n, m, _ub_comm, _stream_comm);
      }

      rs_output_ptr += m_chunk * rs_output.element_size();
    }
  }

  _ub_comm->sms = ori_sms;
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
  }
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
}  // CommOverlapBase::split_overlap_rs
/***************************************************************************************************
 * Comm+GEMM Overlap P2P Base (Ring-Exchange)
 **************************************************************************************************/

CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
                                       int myrank, int numranks, int mylocal, int numlocal,
                                       int mynode, int numnodes, int tp_size,
                                       ExtAllgatherOp allgather_handle,
                                       ExtBarrierOp barrier_handle, CommOverlapType comm_type,
                                       int num_max_streams, int comm_cga_size, int gemm_priority,
                                       int comm_priority, int num_comm_sm, bool set_sm_margin,
                                       bool use_ce, bool atomic_gemm, bool aggregate)
    : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size,
                      allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
                      gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                      atomic_gemm) {
  _is_p2p = true;
  _is_reduce_scatter = comm_type == CommOverlapType::RS;
  _aggregate = aggregate;

  // Create workspace tensor with userbuffer
  NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!");
  size_t buffer_bytes = buffer_shape[0] * buffer_shape[1] * typeToSize(buffer_dtype);
  int buffer_chunk_bytes = buffer_bytes / tp_size;
  _num_ubuf_chunks = tp_size;
  if (_is_reduce_scatter) {
    // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold received GEMM chunk
    // outputs for reduction at the end of the pipelining.
    buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1);
    _num_ubuf_chunks = tp_size * 2 - 1;
  }

  void *buffer_ptr;
  _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true);
  if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg);
  _ubuf = TensorWrapper(buffer_ptr,
                        {buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]},
                        buffer_dtype);

  // Create tensor chunks for easy management
  char *ubuf_byte_ptr = reinterpret_cast<char *>(buffer_ptr);
  for (int i = 0; i < _num_ubuf_chunks; i++) {
    _ubufs.push_back(TensorWrapper(reinterpret_cast<void *>(ubuf_byte_ptr),
                                   {buffer_shape[0] / tp_size, buffer_shape[1]}, buffer_dtype));
    ubuf_byte_ptr += buffer_chunk_bytes;
  }

  _rank_round_tp = (_rank / _tp_size) * _tp_size;
  _next_rank = (_tp_size + _rank + 1) % _tp_size + _rank_round_tp;
  _prev_rank = (_tp_size + _rank - 1) % _tp_size + _rank_round_tp;

  _self_chunk_id = _tp_id;
  if (_atomic_gemm && !_is_reduce_scatter) {
    _use_multiatomic_ag = getenv<bool>("NVTE_AG_P2P_MULTI_ATOMIC");
    if (_use_multiatomic_ag) {
      _use_ce = 0;
      _ub_comm->push = 1;
      if (_rank == 0) {
        printf("!!userbuffers_sendrecv_multi_atomic_shuffle\n");
      }
    }
    _self_chunk_id = 0;
    NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
  }

  for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
    cudaStream_t stream;
    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
    _stream_send.push_back(std::move(stream));
  }
  NVTE_CHECK_CUDA(
      cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
  NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
}
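The neighbor computation keeps ranks inside their tensor-parallel group: _rank_round_tp is the first rank of the group, and the modulo wraps positions within it. A worked example of the formulas above:

// tp_size = 4, global rank 7 -> group [4..7], position 3 within the group.
int tp_size = 4, rank = 7;
int rank_round_tp = (rank / tp_size) * tp_size;                  // = 4 (group start)
int next_rank = (tp_size + rank + 1) % tp_size + rank_round_tp;  // = 4 (wraps to group start)
int prev_rank = (tp_size + rank - 1) % tp_size + rank_round_tp;  // = 6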
CommOverlapP2PBase::~CommOverlapP2PBase() {
  cudaEventDestroy(_stop_recv);
  cudaEventDestroy(_stop_send);
  cudaStreamDestroy(_stream_recv);
  for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]);
}
TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source,
                                                         size_t chunk_id) {
  // Start with a chunk of the source tensor
  auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape()));

  // Update chunk with offset data pointers from the communication buffer
  if (chunk.dptr() != nullptr) {
    chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape());
  }
  if (chunk.columnwise_dptr() != nullptr) {
    chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape());
  }
  return chunk;
}
/*
** Split AllGather + AtomicGEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/
void CommOverlapP2PBase::atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa,
                                                const TensorWrapper &B, bool transb,
                                                TensorWrapper &D, TensorWrapper &bias,
                                                TensorWrapper &pre_gelu_out,
                                                TensorWrapper &workspace, bool grad,
                                                bool accumulate, bool use_split_accumulator,
                                                TensorWrapper &B_copy, cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get GEMM dimensions between TN and NN input layouts
  const size_t m = (transa) ? A.size(0) : A.size(1);
  const size_t n_chunk = _ubufs[0].size(0);
  assert(pre_gelu_out.numel() == 0);

  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();

  // Create a GEMM output buffer with N+1 chunks in contiguous memory
  void *D_buffer_ptr;
  int D_chunk_bytes = n_chunk * m * D.element_size();
  NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main));
  auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(),
                                D.scale_inv(), D.scale_inv_shape(), D.scaling_mode());

  // Reset atomic counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
  reset_counters(counter_ptr, _tp_size, true, stream_main);

  // Catch up the default torch stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));

  auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape()));
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();
  auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});

  for (int i = 0; i < _tp_size - 1; i++) {
    // Set the userbuffer id. The buffer under send is the input for the current
    // GEMM chunk. The initial input chunk is stored in _ubuf[rank]. This is to
    // have the AG output in all ranks be contiguous after the ring exchanges.
    int send_chunk_id = i;
    int recv_chunk_id = i + 1;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;

    if (_use_multiatomic_ag) {
      if (i == 0) {
        _ub_comm->use_ce = 0;
        userbuffers_sendrecv_multiatomic(_ub_reg, _ub_reg, comm_bytes, comm_bytes, comm_bytes,
                                         _ub_comm, _next_rank, _prev_rank, _tp_size, counter_ptr,
                                         true, _stream_recv);
      }
    } else {
      userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
                       _next_rank, _stream_recv);
      userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
                       _prev_rank, _stream_recv);
      producer(counter_ptr, recv_chunk_id, _stream_recv);
    }
    if (i == 0) {
      nvte_cublas_atomic_gemm(A.data(), input_b.data(), D_buffer.data(), bias.data(),
                              pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                              accumulate, use_split_accumulator, _math_sms, 0, _tp_size, false,
                              _counter.data(), stream_main);
    }
  }

  // Store the input activation for backprop
  if (B_copy.numel() > 0) {
    assert(B_copy.numel() == _ubufs[_self_chunk_id].numel());
    assert(B_copy.element_size() == _ubufs[_self_chunk_id].element_size());
    NVTE_CHECK_CUDA(
        cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(),
                        _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(),
                        cudaMemcpyDeviceToDevice, _stream_send[0]));
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }

  // Copy the first GEMM output chunk to the end chunk position of D_buffer
  char *src_ptr = reinterpret_cast<char *>(D_buffer.dptr());
  NVTE_CHECK_CUDA(cudaMemcpyAsync(src_ptr + (D.numel() * D.element_size()), src_ptr,
                                  D_chunk_bytes, cudaMemcpyDeviceToDevice, stream_main));
  // Return the last N rows of D_buffer
  NVTE_CHECK_CUDA(cudaMemcpyAsync(D.dptr(), src_ptr + D_chunk_bytes,
                                  D.numel() * D.element_size(), cudaMemcpyDeviceToDevice,
                                  stream_main));

  // Clean up buffer allocation
  NVTE_CHECK_CUDA(cudaFreeAsync(D_buffer_ptr, stream_main));

  _ub_comm->sms = ori_sms;
}  // CommOverlapP2PBase::atomic_gemm_overlap_ag
/*
** Split AllGather + GEMM using P2P communication
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG
** outputs in each rank to be in the contiguous memory space after all ring exchange phases.
*/
void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                          const TensorWrapper &B, bool transb, TensorWrapper &D,
                                          TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                                          TensorWrapper &workspace, bool grad, bool accumulate,
                                          bool use_split_accumulator, TensorWrapper &B_copy,
                                          cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get GEMM dimensions between TN and NN input layouts
  const size_t m = (transa) ? A.size(0) : A.size(1);
  const size_t k = (transa) ? A.size(1) : A.size(0);
  const size_t n_chunk = _ubufs[0].size(0);

  // Get communication and GEMM output chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
  const bool do_gelu = pre_gelu_out.numel() > 0;
  size_t input_chunk_size = n_chunk * k;
  size_t output_chunk_size = n_chunk * m;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
  }

  if (_aggregate) {
    const int num_steps = _tp_size / 2;
    input_chunk_size *= 2;
    output_chunk_size *= 2;

    // Initial 1X input chunk exchange between neighboring peers
    int send_chunk_id = _tp_id;
    int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;
    int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
    userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
                     _stream_send[0]);
    userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
                     _stream_recv);
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0));

    int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
    const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
    const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;

    // Ring exchange of 2X input chunks
    for (int i = 0; i < num_steps; i++) {
      send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
      recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
      send_offset = comm_bytes * send_chunk_id;
      recv_offset = comm_bytes * recv_chunk_id;

      // GEMM
      auto input_b_chunk =
          get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k});
      auto output_chunk =
          get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m});
      auto aux_chunk = (do_gelu)
                           ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id,
                                              {n_chunk * 2, k})
                           : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                       use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);

      if (i < num_steps - 1) {
        // P2P communication
        userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
                         next_rank, _stream_send[0]);
        userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
                         prev_rank, _stream_recv);
        NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
        NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
        NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()],
                                            _stop_recv, 0));
      } else if (B_copy.numel() > 0) {
        assert(B_copy.numel() == _ubufs[_tp_id].numel());
        assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
        NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
                                        _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                        cudaMemcpyDeviceToDevice, _stream_send[0]));
      }
    }
  } else {
    for (int i = 0; i < _tp_size; i++) {
      // Set the userbuffer id. The buffer under send is the input for the current
      // GEMM chunk. The initial input chunk is stored in _ubuf[rank]. This is to
      // have the AG output in all ranks be contiguous after the ring exchanges.
      int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
      int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
      int send_offset = comm_bytes * send_chunk_id;
      int recv_offset = comm_bytes * recv_chunk_id;

      // GEMM
      auto input_b_chunk =
          get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k});
      auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m});
      auto aux_chunk = (do_gelu)
                           ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id,
                                              {n_chunk, k})
                           : TensorWrapper(nullptr, std::vector<size_t>{0}, pre_gelu_out.dtype());
      auto workspace_chunk = get_tensor_chunk(
          workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});

      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                       use_split_accumulator, _math_sms,
                       _stream_compute[i % _stream_compute.size()]);

      if (i < _tp_size - 1) {
        // P2P communication
        userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
                         _next_rank, _stream_send[0]);
        userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
                         _prev_rank, _stream_recv);
        NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
        NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0));
        NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()],
                                            _stop_recv, 0));
      } else if (B_copy.numel() > 0) {
        assert(B_copy.numel() == _ubufs[_tp_id].numel());
        assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
        NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(),
                                        _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                        cudaMemcpyDeviceToDevice, _stream_send[0]));
      }
    }
  }

  _ub_comm->sms = ori_sms;
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
  }
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0]));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
}  // CommOverlapP2PBase::split_overlap_ag
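In the non-aggregated loop each rank starts from its own chunk and walks backwards around the ring, so after tp_size - 1 exchanges every rank holds all chunks contiguously. A worked trace of the indexing for tp_size = 4 on rank tp_id = 1:

// send_chunk_id = (tp_size + tp_id - i) % tp_size; recv is one step further back:
// i = 0: send chunk 1 (own chunk), recv chunk 0
// i = 1: send chunk 0,             recv chunk 3
// i = 2: send chunk 3,             recv chunk 2
// i = 3: last GEMM consumes chunk 2; no send/recv since i == tp_size - 1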
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa,
                                                const TensorWrapper &B, bool transb,
                                                TensorWrapper &D, TensorWrapper &bias,
                                                TensorWrapper &pre_gelu_out,
                                                TensorWrapper &workspace, bool grad,
                                                bool accumulate, bool use_split_accumulator,
                                                TensorWrapper &rs_output,
                                                cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get communication and GEMM input chunk sizes
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();

  // Reset counters
  int *counter_ptr = reinterpret_cast<int *>(_counter.dptr());
  reset_counters(counter_ptr, _tp_size, false, stream_main);

  // Catch up the main stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));

  // Atomic GEMM
  // Process GEMM chunks in the order that AG+GEMM places the output chunks.
  auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape()));
  nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(),
                          transa, transb, grad, workspace.data(), accumulate,
                          use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(),
                          stream_main);

  // P2P communication chunk
  for (int i = 1; i < _tp_size; i++) {
    int send_chunk_id = i - 1;
    int recv_chunk_id = send_chunk_id + _tp_size;
    int send_offset = comm_bytes * send_chunk_id;
    int recv_offset = comm_bytes * recv_chunk_id;
    int send_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
    int recv_rank = (_tp_id + i) % _tp_size + _rank_round_tp;

    consumer(counter_ptr, send_chunk_id, _stream_recv);
    userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank,
                     _stream_recv);
    userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank,
                     _stream_recv);
  }
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));

  // Reduce GEMM output chunks
  char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
  if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        D.dtype(), fp8_type,
        reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
                                         _ubufs[0].numel(), stream_main););
  } else {
    reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
  }
  _ub_comm->sms = ori_sms;
}
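The final step sums the tp_size received GEMM outputs element-wise (reduce_bf16, or the FP8-in/BF16-out variant when the buffer holds FP8). A minimal fp32 reference of that element-wise reduction, assuming the chunks are laid out contiguously; the real kernels accumulate in bf16 or dequantize fp8 inputs first:

// Reference: out[j] = sum over c of chunks[c * numel + j], c in [0, num_chunks).
__global__ void reduce_chunks_reference(const float *chunks, float *out, int num_chunks,
                                        int numel) {
  int j = blockIdx.x * blockDim.x + threadIdx.x;
  if (j >= numel) return;
  float acc = 0.f;
  for (int c = 0; c < num_chunks; c++) acc += chunks[c * numel + j];
  out[j] = acc;
}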
/*
** Split ReduceScatter + GEMM using P2P communication
*/
void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
                                          const TensorWrapper &B, bool transb, TensorWrapper &D,
                                          TensorWrapper &bias, TensorWrapper &pre_gelu_out,
                                          TensorWrapper &workspace, bool grad, bool accumulate,
                                          bool use_split_accumulator, TensorWrapper &rs_output,
                                          cudaStream_t stream_main) {
  int ori_sms = _ub_comm->sms;
  _ub_comm->use_ce = _use_ce;
  _ub_comm->sms = _num_comm_sm;
  _ub_comm->cga_size = _cga_size;

  // Get communication and GEMM input chunk sizes
  size_t m = transa ? A.size(0) : A.size(1);
  size_t k = transa ? A.size(1) : A.size(0);
  size_t n_chunk = _ubufs[0].size(0);
  const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();

  // Get input and workspace data pointers
  size_t input_chunk_size = n_chunk * k;
  size_t output_chunk_size = n_chunk * m;
  size_t workspace_size_chunk = workspace.numel() / _stream_compute.size();

  // Catch up the main stream
  NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
  for (size_t i = 0; i < _stream_send.size(); i++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0));
  }
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0));
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0));
  }

  // GEMM and send/recv chunks
  for (int i = 0; i < _tp_size; i++) {
    // GEMM chunk
    int stream_id = i % _stream_compute.size();
    int input_b_chunk_id = (_tp_id + i + 1) % _tp_size;
    auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k});
    auto output_chunk = get_buffer_chunk_by_id(D, i);
    auto workspace_chunk =
        get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});

    nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                     accumulate, use_split_accumulator, _math_sms, _stream_compute[stream_id]);

    if (i > 0) {
      // P2P communication chunk
      int prev_stream_id = (i - 1) % _stream_compute.size();
      int send_offset = comm_bytes * (i - 1);
      int recv_offset = comm_bytes * (i - 1 + _tp_size);
      int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp;
      int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp;
      NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id]));
      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0));
      NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0));
      userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
                       send_rank, _stream_send[prev_stream_id]);
      userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
                       recv_rank, _stream_recv);
    }
  }

  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0));
  }
  for (size_t i = 0; i < _stream_compute.size(); i++) {
    NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i]));
    NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
  }
  NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
  NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));

  // Reduce GEMM output chunks
  char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
  char *rs_output_ptr = reinterpret_cast<char *>(rs_output.dptr());
  if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
        D.dtype(), fp8_type,
        reduce_fp8_in_bf16_out<fp8_type>(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size,
                                         _ubufs[0].numel(), stream_main););
  } else {
    reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main);
  }
  _ub_comm->sms = ori_sms;
}

}  // namespace transformer_engine
transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc (new file, mode 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "ipcsocket.h"
#include <errno.h>
#include <stdarg.h>
#include <stdlib.h>
#include <string.h>
#define IPC_MAX_MSGLEN 4096
void ipc_warn(const char *format, ...) {
  char buffer[IPC_MAX_MSGLEN];
  va_list args;
  va_start(args, format);
  vsnprintf(buffer, IPC_MAX_MSGLEN - 1, format, args);
  snprintf(buffer + strlen(buffer), IPC_MAX_MSGLEN - strlen(buffer) - 1, " : %s (%d)\n",
           strerror(errno), errno);
  fflush(stdout);
  fputs(buffer, stderr);
  fflush(NULL);
  va_end(args);
}
static const char *ipcSocketResultStrings[static_cast<int>(ipcSocketNumResults)] = {
    "Success",
    "Unhandled CUDA error",
    "System error",
    "Internal error",
    "Invalid argument",
    "Invalid usage",
    "Remote error",
    "In progress",
};

const char *ipcSocketGetErrorString(ipcSocketResult_t res) {
  return ipcSocketResultStrings[static_cast<int>(res)];
}
#define USE_ABSTRACT_SOCKET // Enable Linux abstract socket naming
#define IPC_SOCKNAME_STR "/tmp/ub-ipc-socket-%d-%lx"
/*
 * Create a Unix Domain Socket
 */
ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
                                volatile uint32_t *abortFlag) {
  int fd = -1;
  struct sockaddr_un cliaddr;
  char temp[IPC_SOCKNAME_LEN] = "";

  if (handle == NULL) {
    return ipcSocketInternalError;
  }
  handle->fd = -1;
  handle->socketName[0] = '\0';
  if ((fd = socket(AF_UNIX, SOCK_DGRAM, 0)) < 0) {
    ipc_warn("UDS: Socket creation error");
    return ipcSocketSystemError;
  }

  bzero(&cliaddr, sizeof(cliaddr));
  cliaddr.sun_family = AF_UNIX;

  // Create unique name for the socket.
  size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
  if (len > (sizeof(cliaddr.sun_path) - 1)) {
    errno = ENAMETOOLONG;
    ipc_warn("UDS: Cannot bind provided name to socket. Name too large");
    return ipcSocketInternalError;
  }
  strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
  cliaddr.sun_path[0] = '\0';  // Linux abstract socket trick
#else
  unlink(temp);
#endif

  if (bind(fd, (struct sockaddr *)&cliaddr, sizeof(cliaddr)) < 0) {
    ipc_warn("UDS: Binding to socket %s failed", temp);
    close(fd);
    return ipcSocketSystemError;
  }

  handle->fd = fd;
  strcpy(handle->socketName, temp);  // NOLINT(*)
  handle->abortFlag = abortFlag;

  // Mark socket as non-blocking
  if (handle->abortFlag) {
    int flags = fcntl(fd, F_GETFL);
    fcntl(fd, F_SETFL, flags | O_NONBLOCK);
  }

  return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketGetFd(struct IpcSocketHandle *handle, int *fd) {
  if (handle == NULL) {
    errno = EINVAL;
    ipc_warn("ipcSocketSocketGetFd: pass NULL socket");
    return ipcSocketInvalidArgument;
  }
  if (fd) *fd = handle->fd;
  return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle) {
  if (handle == NULL) {
    return ipcSocketInternalError;
  }
  if (handle->fd <= 0) {
    return ipcSocketSuccess;
  }
#ifndef USE_ABSTRACT_SOCKET
  if (handle->socketName[0] != '\0') {
    unlink(handle->socketName);
  }
#endif
  close(handle->fd);
  return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketRecvMsg(IpcSocketHandle *handle, void *hdr, int hdrLen, int *recvFd) {
  struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
  struct iovec iov[1];

  // Union to guarantee alignment requirements for control array
  union {
    struct cmsghdr cm;
    char control[CMSG_SPACE(sizeof(int))];
  } control_un;

  struct cmsghdr *cmptr;
  char dummy_buffer[1];
  int ret;

  msg.msg_control = control_un.control;
  msg.msg_controllen = sizeof(control_un.control);

  if (hdr == NULL) {
    iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer);
    iov[0].iov_len = sizeof(dummy_buffer);
  } else {
    iov[0].iov_base = hdr;
    iov[0].iov_len = hdrLen;
  }

  msg.msg_iov = iov;
  msg.msg_iovlen = 1;

  while ((ret = recvmsg(handle->fd, &msg, 0)) <= 0) {
    if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
      ipc_warn("UDS: Receiving data over socket failed");
      return ipcSocketSystemError;
    }
    if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
  }

  if (recvFd != NULL) {
    if (((cmptr = CMSG_FIRSTHDR(&msg)) != NULL) && (cmptr->cmsg_len == CMSG_LEN(sizeof(int)))) {
      if ((cmptr->cmsg_level != SOL_SOCKET) || (cmptr->cmsg_type != SCM_RIGHTS)) {
        errno = EBADMSG;
        ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
        return ipcSocketSystemError;
      }
      memmove(recvFd, CMSG_DATA(cmptr), sizeof(*recvFd));
    } else {
      errno = ENOMSG;
      ipc_warn("UDS: Receiving data over socket %s failed", handle->socketName);
      return ipcSocketSystemError;
    }
  } else {
    errno = EINVAL;
    ipc_warn("UDS: File descriptor pointer cannot be NULL");
    return ipcSocketInvalidArgument;
  }

  return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *recvFd) {
  return ipcSocketRecvMsg(handle, NULL, 0, recvFd);
}
ipcSocketResult_t ipcSocketSendMsg(IpcSocketHandle *handle, void *hdr, int hdrLen,
                                   const int sendFd, int rank, uint64_t hash) {
  struct msghdr msg = {0, 0, 0, 0, 0, 0, 0};
  struct iovec iov[1];
  char temp[IPC_SOCKNAME_LEN];

  union {
    struct cmsghdr cm;
    char control[CMSG_SPACE(sizeof(int))];
  } control_un;

  struct cmsghdr *cmptr;
  char dummy_buffer[1];
  struct sockaddr_un cliaddr;

  // Construct client address to send this shareable handle to
  bzero(&cliaddr, sizeof(cliaddr));
  cliaddr.sun_family = AF_UNIX;

  size_t len = snprintf(temp, IPC_SOCKNAME_LEN, IPC_SOCKNAME_STR, rank, hash);
  if (len > (sizeof(cliaddr.sun_path) - 1)) {
    errno = ENAMETOOLONG;
    ipc_warn("UDS: Cannot connect to provided name for socket. Name too large");
    return ipcSocketInternalError;
  }
  (void)strncpy(cliaddr.sun_path, temp, len);
#ifdef USE_ABSTRACT_SOCKET
  cliaddr.sun_path[0] = '\0';  // Linux abstract socket trick
#endif

  if (sendFd != -1) {
    msg.msg_control = control_un.control;
    msg.msg_controllen = sizeof(control_un.control);

    cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type = SCM_RIGHTS;
    memmove(CMSG_DATA(cmptr), &sendFd, sizeof(sendFd));
  }

  msg.msg_name = reinterpret_cast<void *>(&cliaddr);
  msg.msg_namelen = sizeof(struct sockaddr_un);

  if (hdr == NULL) {
    iov[0].iov_base = reinterpret_cast<void *>(dummy_buffer);
    iov[0].iov_len = sizeof(dummy_buffer);
  } else {
    iov[0].iov_base = hdr;
    iov[0].iov_len = hdrLen;
  }
  msg.msg_iov = iov;
  msg.msg_iovlen = 1;
  msg.msg_flags = 0;

  ssize_t sendResult;
  while ((sendResult = sendmsg(handle->fd, &msg, 0)) < 0) {
    if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
      ipc_warn("UDS: Sending data over socket %s failed", temp);
      return ipcSocketSystemError;
    }
    if (handle->abortFlag && *handle->abortFlag) return ipcSocketInternalError;
  }

  return ipcSocketSuccess;
}
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int sendFd, int rank,
                                  uint64_t hash) {
  return ipcSocketSendMsg(handle, NULL, 0, sendFd, rank, hash);
}
transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h (new file, mode 100644)
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#define TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <memory.h>
#include <stdio.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <unistd.h>
typedef enum {
  ipcSocketSuccess = 0,
  ipcSocketUnhandledCudaError = 1,
  ipcSocketSystemError = 2,
  ipcSocketInternalError = 3,
  ipcSocketInvalidArgument = 4,
  ipcSocketInvalidUsage = 5,
  ipcSocketRemoteError = 6,
  ipcSocketInProgress = 7,
  ipcSocketNumResults = 8
} ipcSocketResult_t;

const char *ipcSocketGetErrorString(ipcSocketResult_t res);

#define IPC_SOCKNAME_LEN 64

struct IpcSocketHandle {
  int fd;
  char socketName[IPC_SOCKNAME_LEN];
  volatile uint32_t *abortFlag;
};

ipcSocketResult_t ipcSocketInit(IpcSocketHandle *handle, int rank, uint64_t hash,
                                volatile uint32_t *abortFlag);
ipcSocketResult_t ipcSocketClose(IpcSocketHandle *handle);
ipcSocketResult_t ipcSocketGetFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketRecvFd(IpcSocketHandle *handle, int *fd);
ipcSocketResult_t ipcSocketSendFd(IpcSocketHandle *handle, const int fd, int rank, uint64_t hash);

#endif /* TRANSFORMER_ENGINE_USERBUFFERS_IPCSOCKET_H */
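A minimal usage sketch (added for illustration, not part of the commit): passing a POSIX file descriptor between two local ranks with this API. `my_rank`, `peer_rank`, `job_hash`, and `fd_to_share` are hypothetical placeholders, and error handling is reduced to the success check on init.

// Both ranks bind a Unix domain socket keyed by (rank, hash) via ipcSocketInit;
// SCM_RIGHTS then lets the kernel recreate the descriptor in the receiving process.
volatile uint32_t abort_flag = 0;
IpcSocketHandle sock = {0};
if (ipcSocketInit(&sock, my_rank, job_hash, &abort_flag) != ipcSocketSuccess) return;
if (my_rank == 0) {
  ipcSocketSendFd(&sock, fd_to_share, peer_rank, job_hash);  // addressed to the peer's socket name
} else {
  int imported_fd = -1;
  ipcSocketRecvFd(&sock, &imported_fd);  // a fresh fd, valid in this process
}
ipcSocketClose(&sock);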
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp
0 → 100644
View file @
5b6ef054
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <inttypes.h>
#include <math.h>
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <chrono>
#include <iostream>
#include <map>
#include <utility>
#include "common/util/cuda_driver.h"
#include "common/util/cuda_nvml.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "common/util/system.h"
#include "ipcsocket.h"
#include "userbuffers.h"
#ifdef NVTE_UB_WITH_MPI
static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD;
static MPI_Comm EXT_COMM_INTRA;
#define UB_MPI_CHECK(expr) \
do { \
const int mpicode = (expr); \
if (mpicode != MPI_SUCCESS) { \
char mpimsg[MPI_MAX_ERROR_STRING]; \
int mpilen; \
MPI_Error_string(mpicode, mpimsg, &mpilen); \
std::vector<char> errmsg(1024); \
snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", __FILE__, __LINE__, \
__func__, mpimsg); \
throw std::runtime_error(errmsg.data()); \
} \
} while (false)
void ub_mpi_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                      ExtComm comm) {
  int numranks;
  UB_MPI_CHECK(MPI_Comm_size(comm, &numranks));
  assert(globalbytes == numranks * localbytes);
  UB_MPI_CHECK(
      MPI_Allgather(localdata, localbytes, MPI_BYTE, globaldata, localbytes, MPI_BYTE, comm));
}

void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
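For orientation (added, not from the commit): the wrapper assumes `globalbytes == numranks * localbytes` and fills `globaldata` with the rank-ordered concatenation of every rank's `localdata`. A hypothetical call exchanging one fixed-size record per rank, where `MyRecord`, `make_record`, and `numranks` are placeholders:

// Sketch: every rank contributes `mine`; `all` ends up ordered by rank.
MyRecord mine = make_record();
std::vector<MyRecord> all(numranks);
ub_mpi_allgather(all.data(), numranks * sizeof(MyRecord), &mine, sizeof(MyRecord),
                 EXT_COMM_WORLD);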
#else
#define EXT_COMM_WORLD "world"
#define EXT_COMM_INTRA "intra"
#endif
#define MULTICAST_GB_TOTAL 512
#if CUDART_VERSION < 12030
// MNNVL: FABRIC handle support lifted from CUDA 12.3
#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL)
#define CU_IPC_HANDLE_SIZE 64
typedef struct CUmemFabricHandle_st {
  unsigned char data[CU_IPC_HANDLE_SIZE];
} CUmemFabricHandle_v1;
typedef CUmemFabricHandle_v1 CUmemFabricHandle;
#endif
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define IPCCHECK(cmd) \
do { \
ipcSocketResult_t r = cmd; \
if (r != ipcSocketSuccess) { \
printf("Failed, UDS error %s:%d '%s'\n", __FILE__, __LINE__, ipcSocketGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define IPCCHECKGOTO(call, RES, label) \
do { \
RES = call; \
if (RES != ipcSocketSuccess && RES != ipcSocketInProgress) { \
goto label; \
} \
} while (0);
bool has_mnnvl_fabric(int device_id) {
#if CUDA_VERSION < 12040
  if (getenv("NVTE_UBDEBUG")) {
    printf(
        "TransformerEngine does not support multi-node NVLINK "
        "since it was not built with CUDA version >= 12.4.\n");
  }
  return false;
#else
  bool mnnvl_fabric_support = false;
  CUdevice dev;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
  int fabric_handle_supported = 0;
  NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported,
                              CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
  if (fabric_handle_supported) {
    NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2);
    nvmlDevice_t local_device;
    NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device);
    nvmlGpuFabricInfoV_t fabricInfo = {};
    fabricInfo.version = nvmlGpuFabricInfo_v2;
    fabricInfo.clusterUuid[0] = '\0';
    NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo);
    NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown);
    if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') {
      mnnvl_fabric_support = true;
    }
  }
  if (getenv("NVTE_UBDEBUG")) {
    if (mnnvl_fabric_support) {
      printf("MNNVL NVLINK is supported on this platform.\n");
    } else {
      printf("MNNVL NVLINK is not supported on this platform.\n");
    }
  }
  return mnnvl_fabric_support;
#endif
}
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
                                 int numlocal, int mynode, int numnodes,
                                 ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
                                 int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
  *comm = new communicator();

  (*comm)->comm_world = EXT_COMM_WORLD;
  (*comm)->_allgather = ext_allgather;
  (*comm)->_barrier = ext_barrier;
  (*comm)->nranks = numranks;
  (*comm)->myrank = myrank;
  (*comm)->free_region = 0;
  (*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU;

  int cur_dev, ndev;
  cudaDeviceProp device_prop;
  NVTE_CHECK_CUDA(cudaGetDevice(&cur_dev));
  NVTE_CHECK_CUDA(cudaGetDeviceCount(&ndev));
  NVTE_CHECK_CUDA(cudaGetDeviceProperties(&device_prop, cur_dev));
  (*comm)->sm_arch = device_prop.major;
  // (*comm)->use_rr_kernel = device_prop.major == 8;
  (*comm)->use_rr_kernel = 0;
  (*comm)->push = 1;
  (*comm)->use_ce = 0;
  (*comm)->cga_size = 2;
  for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;

  int device_clock = 0;
  // 110 sec wait time by default
  int sec_timeout = getenv("UB_TIMEOUT") ? atoi(getenv("UB_TIMEOUT")) : 110;
  NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&device_clock, cudaDevAttrClockRate, cur_dev));
  (*comm)->ub_timeout = 1000ull * device_clock * sec_timeout;
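  // Added note: cudaDevAttrClockRate reports kHz, so cycles = kHz * 1000 * seconds;
  // e.g. a 1980000 kHz SM clock with the default 110 s timeout gives ~2.18e11 cycles.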
  if ((*comm)->myrank == 0) {
    printf("UB_TIMEOUT is set to %d sec, %" PRIu64 " cycles, freq: %dkhz\n", sec_timeout,
           (*comm)->ub_timeout, device_clock);
  }

  (*comm)->comm_intra = EXT_COMM_INTRA;
  (*comm)->nvrank = mylocal;
  (*comm)->nvsize = numlocal;

  cpu_set_t cpuset;
  CPU_ZERO(&cpuset);
  int core = 50;  // default pin; also covers mylocal > 7, which the table below leaves unset
  if (mylocal == 0) core = 50;
  if (mylocal == 1) core = 58;
  if (mylocal == 2) core = 18;
  if (mylocal == 3) core = 26;
  if (mylocal == 4) core = 114;
  if (mylocal == 5) core = 122;
  if (mylocal == 6) core = 82;
  if (mylocal == 7) core = 90;

  CPU_SET(core, &cpuset);
  if (!getenv("NVTE_NODOUBLE")) {
    if (core > 128)
      CPU_SET(core - 128, &cpuset);
    else
      CPU_SET(core + 128, &cpuset);
  }
  if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);

  if (ndev == numlocal) {  // all visible devices
    if (cur_dev != mylocal)
      printf("%d: device used %d[%d], resetting device to %d\n", myrank, cur_dev, ndev, mylocal);
    NVTE_CHECK_CUDA(cudaSetDevice(mylocal));
  }
  (*comm)->mydev = cur_dev;
  // FIXME need to check that numlocal is multiple of pipegpus x tensorgpus
  // ar1 is data
  int divgpus = pipegpus * tensorgpus;
  int datagpus = numlocal / divgpus;
  (*comm)->ar_nvsize = datagpus;
  (*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus;
  (*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus;
  // ar2 is tensor
  (*comm)->ar2_nvsize = tensorgpus;
  (*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus;
  (*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu;
  // ar2 has step equal to ar_nvsize
  int allnodes = numranks / numlocal;
  int nodeid = myrank / numlocal;
  (*comm)->num_nodes = numnodes;
  (*comm)->my_node = mynode;
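  // Worked example (added for illustration): numlocal = 8, pipegpus = 1, tensorgpus = 2
  // gives datagpus = 4. For mylocal = 5:
  //   ar_firstgpu  = 5 - ((5 / 2) % 4) * 2 = 1, ar_nvrank  = (5 - 1) / 2 = 2
  //   ar2_firstgpu = 5 - 5 % 2 = 4,             ar2_nvrank = 1
  // i.e. GPU 5 is rank 2 of the data-parallel group {1, 3, 5, 7} and rank 1 of the
  // tensor-parallel pair {4, 5}.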
#define NBUF 2

#if CUDART_VERSION >= 12010
  bool mnnvl_fabric = has_mnnvl_fabric(cur_dev);
  if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
      transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
    // multicast init only for TP ops (____2 operations)
    size_t mc_maxsize = MULTICAST_GB_TOTAL * (1ull << 30);
    (*comm)->mc_offset = 0;
    (*comm)->use_mc = 1;
    size_t gran;
    CUmulticastObjectProp mcProp = {};
    mcProp.numDevices = (*comm)->ar2_nvsize;
    mcProp.size = (*comm)->mc_maxsize;
    mcProp.handleTypes =
        mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
    NVTE_CALL_CHECK_CUDA_DRIVER(
        cuMulticastGetGranularity, &gran, &mcProp,
        static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_RECOMMENDED));
    mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran;
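    // Added example: with a recommended granularity of 2 MiB, a size that is not a
    // multiple of it (say 3 MiB) rounds up to the next multiple (4 MiB); the 512 GiB
    // default reservation is already aligned and stays unchanged.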
    mcProp.size = mc_maxsize;
    (*comm)->mc_maxsize = mc_maxsize;

    if ((*comm)->ar2_nvrank == 0)
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp);

    if (mnnvl_fabric) {
      CUmemFabricHandle *exphndl =
          reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
      CUmemFabricHandle *tmphndl =
          reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
      CUmemFabricHandle *exphndls;
      NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle)));
      if ((*comm)->ar2_nvrank == 0)
        NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast<void *>(tmphndl),
                                    (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
      for (int grp = 0; grp < (*comm)->ar_nvsize; grp++) {
        // we do N broadcasts for N TP groups in the NVL domain
        int root = grp * (*comm)->ar2_nvsize;
        // This only needs to be a bcast, but we reuse the existing allgather comm
        (*comm)->_allgather(reinterpret_cast<void *>(exphndls),
                            (*comm)->nvsize * sizeof(CUmemFabricHandle),
                            reinterpret_cast<void *>(tmphndl), sizeof(CUmemFabricHandle),
                            (*comm)->comm_intra);
        // save the handle if the broadcast came from rank 0 of our group
        if ((*comm)->ar2_firstgpu == root)
          memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle));
      }
      if ((*comm)->ar2_nvrank != 0)
        NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle,
                                    reinterpret_cast<void *>(exphndl), CU_MEM_HANDLE_TYPE_FABRIC);
      free(exphndl);
      free(tmphndl);
      NVTE_CHECK_CUDA(cudaFreeHost(exphndls));
    } else {
      // Broadcast a POSIX file descriptor from the local root rank to the other local ranks.
      // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
      //       file descriptor and prevent cuMemImportFromShareableHandle() from correctly
      //       interpreting the file. Instead, we use Unix domain sockets for the kernel to
      //       recreate the correct file descriptor on every receiving rank.
      int fd;
      volatile uint32_t abortFlag = 0;
      IpcSocketHandle ipcSock = {0};
      uint64_t opId = 0xdeadcafe0000 + (*comm)->my_node;
      ipcSocketResult_t ret = ipcSocketSuccess;

      IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
      (*comm)->_barrier((*comm)->comm_world);

      if ((*comm)->ar2_nvrank == 0) {
        NVTE_CALL_CHECK_CUDA_DRIVER(
            cuMemExportToShareableHandle, reinterpret_cast<void *>(&fd), (*comm)->mc_handle,
            static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
            (uint64_t)0);
        for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
          (*comm)->_barrier((*comm)->comm_intra);
          IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
        }
      } else {
        for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
          (*comm)->_barrier((*comm)->comm_intra);
          if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error);
        }
      }

    error:
      if ((*comm)->ar2_nvrank != 0) {
        NVTE_CALL_CHECK_CUDA_DRIVER(
            cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd),
            static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
      }
      IPCCHECK(ipcSocketClose(&ipcSock));
      close(fd);
    }
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
                                (CUdeviceptr)(*comm)->mydev);

    CUdeviceptr mc_va;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &mc_va, mc_maxsize, (size_t)0,
                                (CUdeviceptr)0U, (uint64_t)0);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, mc_va, mc_maxsize, (size_t)0, (*comm)->mc_handle,
                                (uint64_t)0);

    CUmemAccessDesc accessDesc = {};
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.location.id = (*comm)->mydev;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, mc_va, mc_maxsize,
                                const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);

    (*comm)->mc_baseptr = reinterpret_cast<void *>(mc_va);
    (*comm)->_barrier((*comm)->comm_world);
    if (!(*comm)->myrank) printf("MC initialized successfully, window size = %ld\n", mc_maxsize);
  } else {
#endif
    if (!(*comm)->myrank) printf("MC NOT initialized and used\n");
    (*comm)->mc_maxsize = 0;
    (*comm)->mc_offset = 0;
    (*comm)->use_mc = 0;
#if CUDART_VERSION >= 12010
  }
#endif
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
  // peer pointers + op flags + comm buffer

  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
  register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, true);
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
  NVTE_CHECK_CUDA(
      cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
  (*comm)->sms = 16;
  (*comm)->threads = 1024;

#define GPU_PAGE_SHIFT 16
#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT)
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)

  NVTE_CHECK_CUDA(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
  NVTE_CHECK_CUDA(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
  (*comm)->flags =
      reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
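  // Added note: over-allocating two GPU pages and rounding the pointer up to the next
  // GPU_PAGE_SIZE (64 KiB) boundary leaves one fully aligned page for the flags,
  // e.g. 0x7f0000012340 -> 0x7f0000020000.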
  using namespace std;

  sched_param param;
  pthread_attr_t attr;
  pthread_attr_init(&attr);
  pthread_attr_getschedparam(&attr, &param);
  param.sched_priority = sched_get_priority_max(SCHED_FIFO);
  pthread_attr_setschedparam(&attr, &param);

  if (getenv("NVTE_UBDEBUG"))
    printf(
        "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP "
        "%dx%d\n",
        myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
        (*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize,
        (*comm)->num_nodes, (*comm)->ar2_nvsize);
  fflush(NULL);

  return 0;
}
int create_communicator_grouped(communicator **comm, int myrank, int numranks, int mylocal,
                                int numlocal, int mynode, int numnodes,
                                ExtAllgatherOp ext_allgather, ExtBarrierOp ext_barrier,
                                int pipegpus, int pipenodes) {
  return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                      ext_allgather, ext_barrier, pipegpus, pipenodes, 1, 1);
}

int create_communicator(communicator **comm, int myrank, int numranks, int mylocal, int numlocal,
                        int mynode, int numnodes, ExtAllgatherOp ext_allgather,
                        ExtBarrierOp ext_barrier) {
  return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                      ext_allgather, ext_barrier, 1, 1, 1, 1);
}
int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipenodes,
                                     int tensorgpus, int tensornodes) {
#ifdef NVTE_UB_WITH_MPI
  // get global numbers
  int myrank, numranks;
  UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank));
  UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks));

  int mylocal, numlocal;
  UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA));
  UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal));
  UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal));

  // find internode numbers and make internode communicator
  NVTE_CHECK_CUDA(cudaFree(0));
  int mynode, numnodes;
  mynode = myrank / numlocal;
  numnodes = numranks / numlocal;

  // finally call the abstracted constructor with MPI info
  return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
                                      &ub_mpi_allgather, &ub_mpi_barrier, pipegpus, pipenodes,
                                      tensorgpus, tensornodes);
#else
  NVTE_ERROR(std::string("Bootstrapping Userbuffers with MPI requires building ") +
             std::string("Transformer Engine with NVTE_UB_WITH_MPI=1 and MPI_HOME=/path/to/mpi"));
#endif
}
int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes) {
  return create_communicator_grouped2_mpi(comm, pipegpus, pipenodes, 1, 1);
}

int create_communicator_mpi(communicator **comm) {
  return create_communicator_grouped2_mpi(comm, 1, 1, 1, 1);
}
void destroy_communicator(communicator *comm) {
  for (int hndl = 0; hndl < comm->free_region; hndl++) {
    if (comm->use_mc && comm->mem_dealloc[hndl]) {
      for (int rank = 0; rank < comm->nvsize; rank++) {
        if (rank == comm->nvrank) {
          NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
        } else {
          comm->uchandles[hndl][rank] = 0;
        }
      }
      free(reinterpret_cast<void *>(comm->uchandles[hndl]));
    } else {
      for (int rank = 0; rank < comm->nvsize; rank++) {
        if (rank != comm->nvrank) {
          cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
        } else if (comm->mem_dealloc[hndl]) {
          NVTE_CHECK_CUDA(cudaFree(comm->peer_ptr[hndl][rank]));
        } else {
          comm->peer_ptr[hndl][rank] = nullptr;  // remove reference to external buffer
        }
      }
    }
    free(comm->peer_ptr[hndl]);
    comm->mem_ptr[hndl] = nullptr;
  }
  cudaFree(reinterpret_cast<void *>(comm->recv_id));
  cudaFree(reinterpret_cast<void *>(comm->send_id));
  if (comm->use_mc) {
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
  }
  delete comm;
}
void destroy_communicator_mpi(communicator *comm) {
#ifdef NVTE_UB_WITH_MPI
  MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_intra)));
  destroy_communicator(comm);
#else
  NVTE_ERROR(std::string("Communicator is not bootstrapped with MPI and ") +
             std::string("can only be deallocated with destroy_communicator()."));
#endif
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
  // free_region indexes arrays sized NVTE_MAX_REGIONS, so the last valid slot is
  // NVTE_MAX_REGIONS - 1.
  if (comm->free_region >= NVTE_MAX_REGIONS) return -1;
  int hndl = comm->free_region;
  comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
  size_t aligned_size = bytes;
  comm->memflags[hndl] = 0;
  comm->mem_dealloc[hndl] = alloc;
#if CUDART_VERSION >= 12010
  if (comm->use_mc && alloc) {
    bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev);
    int nranks = comm->nvsize;  // total GPUs in NVLINK domain
    int myrank = comm->nvrank;
    void **remptrs = reinterpret_cast<void **>(malloc(nranks * sizeof(void *)));

    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = comm->mydev;
    prop.requestedHandleTypes =
        mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

    size_t granularity = 0;
    NVTE_CALL_CHECK_CUDA_DRIVER(
        cuMemGetAllocationGranularity, &granularity, &prop,
        static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
    // MPI_Allreduce MAX of granularity check
    aligned_size = (bytes + granularity - 1) / granularity * granularity;

    if (comm->use_mc) {
      CUmulticastObjectProp mcProp = {};
      mcProp.numDevices = nranks;
      mcProp.size = aligned_size;
      mcProp.handleTypes = prop.requestedHandleTypes;
      NVTE_CALL_CHECK_CUDA_DRIVER(
          cuMulticastGetGranularity, &granularity, &mcProp,
          static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
      aligned_size = (aligned_size + granularity - 1) / granularity * granularity;
    }

    prop.location.id = comm->mydev;
    comm->uchandles[hndl] = reinterpret_cast<CUmemGenericAllocationHandle *>(
        malloc(nranks * sizeof(CUmemGenericAllocationHandle)));
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop,
                                (uint64_t)0);
    if (mnnvl_fabric) {
      CUmemFabricHandle *exphndl;
      CUmemFabricHandle myhndl;
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl,
                                  comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0);
      NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle)));
      comm->_allgather(reinterpret_cast<void *>(exphndl), comm->nvsize * sizeof(CUmemFabricHandle),
                       reinterpret_cast<void *>(&myhndl), sizeof(CUmemFabricHandle),
                       comm->comm_intra);
      for (int p = 0; p < nranks; p++)
        if (p != myrank)
          NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
                                      reinterpret_cast<void *>(&exphndl[p]),
                                      CU_MEM_HANDLE_TYPE_FABRIC);
      NVTE_CHECK_CUDA(cudaFreeHost(exphndl));
    } else {
      int *peerfd = reinterpret_cast<int *>(malloc(nranks * sizeof(int)));
      NVTE_CALL_CHECK_CUDA_DRIVER(
          cuMemExportToShareableHandle, reinterpret_cast<void *>(&peerfd[myrank]),
          comm->uchandles[hndl][myrank],
          static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
          (uint64_t)0);

      volatile uint32_t abortFlag = 0;
      IpcSocketHandle ipcSock = {0};
      uint64_t opId = 0xdeadcafe0000 + comm->my_node;
      ipcSocketResult_t ret = ipcSocketSuccess;

      // All-gather POSIX file descriptors across local ranks
      IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
      for (int p = 1; p < nranks; p++) {
        int send_to = (myrank + p) % nranks;
        int recv_from = (myrank + nranks - p) % nranks;
        comm->_barrier(comm->comm_intra);
        IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret,
                     error);
        IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error);
      }
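      // Worked example (added for illustration): nranks = 4, myrank = 1.
      //   p = 1: send my fd to rank 2, receive rank 0's fd
      //   p = 2: send my fd to rank 3, receive rank 3's fd
      //   p = 3: send my fd to rank 0, receive rank 2's fd
      // After the loop, peerfd[] holds a locally valid descriptor for every peer.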
    error:
      IPCCHECK(ipcSocketClose(&ipcSock));

      for (int p = 0; p < nranks; p++) {
        if (p != myrank)
          NVTE_CALL_CHECK_CUDA_DRIVER(
              cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
              reinterpret_cast<void *>(peerfd[p]),
              static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
        close(peerfd[p]);
      }
      free(peerfd);
    }
    CUdeviceptr ptr;
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks),
                                (size_t)0, (CUdeviceptr)0, (uint64_t)0);
    comm->ucbase_ptr[hndl] = reinterpret_cast<void *>(ptr);

    CUmemAccessDesc accessDesc = {};
    accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    accessDesc.location.id = comm->mydev;

    for (int i = 0; i < nranks; i++) {
      remptrs[i] = reinterpret_cast<void *>(ptr + (aligned_size * i));
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, reinterpret_cast<CUdeviceptr>(remptrs[i]),
                                  aligned_size, (size_t)0, comm->uchandles[hndl][i], (uint64_t)0);
      if (i == comm->nvrank) {
        if (hndl)
          *gpubuff = remptrs[i];
        else
          comm->gpu_ptrs = remptrs[i];
      }
      comm->peer_ptr[hndl][i] = remptrs[i];
    }
    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, ptr, (size_t)(aligned_size * nranks),
                                const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);

    if (hndl == 0) NVTE_CHECK_CUDA(cudaMemset(comm->gpu_ptrs, 0, aligned_size));
    NVTE_CHECK_CUDA(
        cudaMemcpy((reinterpret_cast<char *>(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)),
                   remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice));
    free(remptrs);
    comm->memflags[hndl] = NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED;

    if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) {
      NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset,
                                  comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/,
                                  aligned_size, (uint64_t)0);
      comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED;
      comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
      comm->mc_offset += aligned_size;
    } else if (!comm->myrank) {
      printf("UB: warning region %d size %ld MB registered without MC access\n", hndl,
             aligned_size / 1024 / 1024);
    }
  } else {
#endif
    if (alloc) {
      NVTE_CHECK_CUDA(cudaMalloc(gpubuff, bytes));
      NVTE_CHECK_CUDA(cudaMemset(*gpubuff, 0, bytes));
    }

    NVTE_CHECK(comm->nvsize <= 8, "CUDA IPC supports only up to 8 GPUs in an NVLink domain.");
    cudaIpcMemHandle_t memhndl;
    NVTE_CHECK_CUDA(cudaIpcGetMemHandle(&memhndl, *gpubuff));

    cudaIpcMemHandle_t *tmp =
        reinterpret_cast<cudaIpcMemHandle_t *>(malloc(comm->nvsize * sizeof(cudaIpcMemHandle_t)));
    comm->_allgather(reinterpret_cast<void *>(tmp), comm->nvsize * sizeof(cudaIpcMemHandle_t),
                     reinterpret_cast<void *>(&memhndl), sizeof(cudaIpcMemHandle_t),
                     comm->comm_intra);

    for (int i = 0; i < comm->nvsize; i++) {
      if (i != comm->nvrank) {
        NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i],  // NOLINT(*)
                                             cudaIpcMemLazyEnablePeerAccess));
      }
    }
    comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
    NVTE_CHECK_CUDA(cudaDeviceSynchronize());

    NVTE_CHECK_CUDA(cudaMemcpy(
        reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
        comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
    NVTE_CHECK_CUDA(cudaDeviceSynchronize());
    free(tmp);
#if CUDART_VERSION >= 12010
  }
#endif
  comm->mem_size[hndl] = aligned_size;
  comm->mem_ptr[hndl] = *gpubuff;

  return comm->free_region++;
}
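For orientation (added, not from the commit): a minimal end-to-end sketch of how these entry points compose when Transformer Engine is built with NVTE_UB_WITH_MPI=1. `tp_size` and `buf_bytes` are illustrative values, not prescribed defaults.

// Sketch only: bootstrap a communicator over MPI, register one region, tear down.
#include <mpi.h>
#include "userbuffers.h"

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  const int tp_size = 2;               // hypothetical tensor-parallel width
  const size_t buf_bytes = 1 << 20;    // hypothetical region size
  communicator *comm = nullptr;
  create_communicator_grouped2_mpi(&comm, /*pipegpus=*/1, /*pipenodes=*/1,
                                   /*tensorgpus=*/tp_size, /*tensornodes=*/1);
  void *buf = nullptr;
  int region = register_user_buffer_collective(&buf, buf_bytes, comm, /*alloc=*/true);
  // ... issue userbuffers collectives against `region` / `buf` here ...
  destroy_communicator_mpi(comm);
  MPI_Finalize();
  return 0;
}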