OpenDAS / TransformerEngine · Commits
"...AutoBuildImmortalWrt.git" did not exist on "33dea230e6682e5ae9defaf1c5a6d4540440d973"

Commit dfd264c3, authored Apr 11, 2025 by yuguo

[DCU] tmp fix p2p overlap

Parent: 24b1c0ff

Showing 6 changed files with 110 additions and 9 deletions:
tests/pytorch/distributed/run_gemm_with_overlap.py                         +4   -1
tests/pytorch/distributed/test_comm_gemm_overlap.py                        +5   -4
transformer_engine/common/amd_detail/hip_float8.h                          +16  -1
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu     +79  -1
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h   +4   -0
transformer_engine/pytorch/module/base.py                                  +2   -2

tests/pytorch/distributed/run_gemm_with_overlap.py

@@ -16,6 +16,7 @@ from functools import partial, reduce
 import torch
 import torch.distributed as dist
 from torch.distributed.elastic.multiprocessing.errors import record
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine.pytorch as te
 import transformer_engine.pytorch.cpp_extensions as tex

@@ -311,6 +312,7 @@ def _main(opts):
         helper,
         tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
         opts.comm_type,
+        num_max_streams=2 if IS_HIP_EXTENSION else 3,
         set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
         atomic_gemm=opts.atomic,
         aggregate=opts.aggregate,

@@ -322,6 +324,7 @@ def _main(opts):
             buffer_dtype,
             helper,
             tp_size,  # Tensor-parallel group size (may be different than LOCAL_SIZE)
+            num_max_streams=1 if IS_HIP_EXTENSION else 3,
             atomic_gemm=opts.atomic,
         )
     )

@@ -398,7 +401,7 @@ def _main(opts):
     )

     # Allocate cuBLAS workspace
-    workspace_size = 3 * get_cublas_workspace_size_bytes()
+    workspace_size = 2 * get_cublas_workspace_size_bytes()
     workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")

     # Gather global tensors and calculate reference result (need these first for Fp8 scales)

tests/pytorch/distributed/test_comm_gemm_overlap.py

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-# mpirun -np 4 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=1024 --batch-size=2 --num-heads=16 --head-dim=48 --comm-type=AG --p2p
+# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=2048 --batch-size=2 --num-heads=96 --head-dim=128 --comm-type=AG --p2p
+# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=2048 --batch-size=2 --num-heads=96 --head-dim=128 --comm-type=RS --p2p
 import os
 import subprocess
 from pathlib import Path

@@ -19,10 +20,10 @@ if torch.cuda.device_count() < 2:
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

 RNG_SEED: int = 42
-SEQ_LENGTH: int = 1024
+SEQ_LENGTH: int = 2048
 BATCH_SIZE: int = 2
-NUM_HEADS: int = 16
-HEAD_DIM: int = 48
+NUM_HEADS: int = 96
+HEAD_DIM: int = 128
 TE_LAYERS = [
     te.Linear,
     te.LayerNormLinear,

transformer_engine/common/amd_detail/hip_float8.h

@@ -430,7 +430,22 @@ struct hip_f8 {
 #endif // #ifdef __gfx942__

   // convert to hip_bfloat16
-  explicit inline HIP_HOST_DEVICE operator __hip_bfloat16() const;
+  explicit inline HIP_HOST_DEVICE operator __hip_bfloat16() const {
+    if (T == hip_f8_type::bf8) {
+      if (get_hip_f8_bias_mode()) {
+        return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<2, 5, float, true /*negative_zero_nan*/>(data));
+      } else {
+        return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<2, 5, float, false /*negative_zero_nan*/>(data));
+      }
+    } else /* fp8*/ {
+      if (get_hip_f8_bias_mode()) {
+        return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<3, 4, float, true /*negative_zero_nan*/>(data));
+      } else {
+        return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<3, 4, float, false /*negative_zero_nan*/>(data));
+      }
+    }
+  }

   // check for zero
   inline HIP_HOST_DEVICE bool is_zero() const {
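
The two branches differ only in the cast_from_f8 template arguments: <2, 5> for bf8 and <3, 4> for fp8. Assuming those are the mantissa and exponent widths (so bf8 is e5m2 and fp8 is e4m3), the conversion can be illustrated with a minimal host-side decode sketch; it skips the NaN/Inf encodings and the alternate bias mode that hip_f8_impl::cast_from_f8 also handles.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an 8-bit float with `man_bits` mantissa bits and `exp_bits` exponent
// bits (the remaining top bit is the sign). Normals and subnormals only.
static float decode_f8(uint8_t v, int man_bits, int exp_bits) {
  const int bias = (1 << (exp_bits - 1)) - 1;
  const int sign = v >> (exp_bits + man_bits);
  const int exp = (v >> man_bits) & ((1 << exp_bits) - 1);
  const int man = v & ((1 << man_bits) - 1);
  float mag;
  if (exp == 0) {
    // subnormal: man * 2^(1 - bias - man_bits)
    mag = std::ldexp(static_cast<float>(man), 1 - bias - man_bits);
  } else {
    // normal: (1 + man / 2^man_bits) * 2^(exp - bias)
    mag = std::ldexp(1.0f + static_cast<float>(man) / (1 << man_bits), exp - bias);
  }
  return sign ? -mag : mag;
}

int main() {
  printf("e4m3 0x38 -> %g\n", decode_f8(0x38, /*man_bits=*/3, /*exp_bits=*/4));  // 1.0
  printf("e5m2 0x3c -> %g\n", decode_f8(0x3c, /*man_bits=*/2, /*exp_bits=*/5));  // 1.0
  return 0;
}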

transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

@@ -11,6 +11,9 @@
 #if __CUDA_ARCH__ >= 800
 #include <cuda_bf16.h>
 #define half_dtype nv_bfloat16
+#elif defined(__HIP_PLATFORM_AMD__)
+#include <cuda_bf16.h>
+#define half_dtype __hip_bfloat16
 #else
 #include <cuda_fp16.h>
 #define half_dtype half

@@ -358,9 +361,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
   reduce_id = (*reduceidptr) + 1;
   flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
+  __threadfence_system();
   if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
+  __threadfence_system();
   volatile int *flag = (volatile int *)&(myptr[targetgpu]);
+  __threadfence_system();
   userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
+  __threadfence_system();
   clock_t s = clock64();
   while (CHECK_IDS(*flag, reduce_id)) {
     if (CHECK_TIMEOUT(s, ub_timeout)) {

@@ -404,8 +411,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
       (reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
   }
+  __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
+  __threadfence_system();
 }

 // fp16 reduce-scatter kernel (out of place)
 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
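
The extra __threadfence_system() calls strengthen the flag handshake between peer GPUs: payload and flag writes to peer-visible memory are explicitly ordered before the flag is read or published. The pattern being protected, sketched as a standalone producer/consumer pair (illustrative only, these are not the library's kernels):

#include <cstdio>
#include <cuda_runtime.h>

// Producer: write the payload, fence it system-wide, then publish the flag.
__global__ void producer(int *data, volatile int *flag, int epoch) {
  data[threadIdx.x] = threadIdx.x + epoch;
  __threadfence_system();               // payload must be visible before the flag
  if (threadIdx.x == 0) *flag = epoch;  // publish
}

// Consumer: spin on the flag, fence, then read the payload.
__global__ void consumer(const int *data, volatile int *flag, int epoch, int *out) {
  while (*flag != epoch) { }            // wait for the producer's epoch
  __threadfence_system();               // order the flag read before the data reads
  out[threadIdx.x] = data[threadIdx.x];
}

int main() {
  const int n = 32, epoch = 1;
  int *data, *out, *flag;
  cudaMallocManaged(&data, n * sizeof(int));
  cudaMallocManaged(&out, n * sizeof(int));
  cudaMallocManaged(&flag, sizeof(int));
  *flag = 0;
  producer<<<1, n>>>(data, flag, epoch);
  consumer<<<1, n>>>(data, flag, epoch, out);
  cudaDeviceSynchronize();
  printf("out[5] = %d\n", out[5]);      // 5 + epoch
  cudaFree(data); cudaFree(out); cudaFree(flag);
  return 0;
}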

@@ -2082,7 +2090,11 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
 #endif

 __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
+#ifdef __HIP_PLATFORM_AMD__
+  *flagptr = *flagptr + 1;
+#else
   atomicAdd_system(flagptr, 1);
+#endif
 }

 __global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); }
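
On the DCU/HIP path the system-scope atomic increment becomes a plain load-increment-store. Collapsed into a helper (a sketch of the repeated #ifdef pattern, not code from this patch), the assumption is that each flag has a single writer and that the surrounding __threadfence_system() calls supply the ordering atomicAdd_system would otherwise provide:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical helper summarizing the pattern; not part of the actual diff.
__device__ __forceinline__ void bump_flag(int *flagptr) {
#ifdef __HIP_PLATFORM_AMD__
  *flagptr = *flagptr + 1;       // plain RMW: single writer, ordered by explicit fences
#else
  atomicAdd_system(flagptr, 1);  // system-scope atomic, visible to peer GPUs and the host
#endif
}

__global__ void pullsend_like(int *flagptr) { bump_flag(flagptr); }

int main() {
  int *flag = nullptr;
  cudaMallocManaged(&flag, sizeof(int));
  *flag = 41;
  pullsend_like<<<1, 1>>>(flag);
  cudaDeviceSynchronize();
  printf("flag = %d\n", *flag);  // 42
  cudaFree(flag);
  return 0;
}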

@@ -2153,10 +2165,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
+#ifdef __HIP_PLATFORM_AMD__
+    *flagptr = *flagptr + 1;
+#else
     atomicAdd_system(flagptr, 1);  // otherwise need local SM sync before sending flag
+#endif
   } else {
     // 0 bytes and 1 SM only
+#ifdef __HIP_PLATFORM_AMD__
+    *flagptr = *flagptr + 1;
+#else
     atomicAdd_system(flagptr, 1);
+#endif
   }
 }

@@ -2215,10 +2235,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
+#ifdef __HIP_PLATFORM_AMD__
+    *send_flagptr = *send_flagptr + 1;
+#else
     atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
+#endif
   } else {
     // 0 bytes and 1 SM only
+#ifdef __HIP_PLATFORM_AMD__
+    *send_flagptr = *send_flagptr + 1;
+#else
     atomicAdd_system(send_flagptr, 1);
+#endif
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {

@@ -2273,10 +2301,18 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
+#ifdef __HIP_PLATFORM_AMD__
+    *send_flagptr = *send_flagptr + 1;
+#else
     atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
+#endif
   } else {
     // 0 bytes and 1 SM only
+#ifdef __HIP_PLATFORM_AMD__
+    *send_flagptr = *send_flagptr + 1;
+#else
     atomicAdd_system(send_flagptr, 1);
+#endif
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {

@@ -2346,11 +2382,19 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiatomic
     __syncthreads();
     if (!threadIdx.x) {
       __threadfence_system();
+#ifdef __HIP_PLATFORM_AMD__
+      *send_flagptr = *send_flagptr + 1;
+#else
       atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
+#endif
     }
   } else {
     // 0 bytes and 1 SM only
+#ifdef __HIP_PLATFORM_AMD__
+    *send_flagptr = *send_flagptr + 1;
+#else
     atomicAdd_system(send_flagptr, 1);
+#endif
   }

   // wait for message to arrive.

@@ -2422,6 +2466,9 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiatomic
 void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
                       const size_t dstoffset, const size_t bytes, communicator *comm,
                       const int peer, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
   int peerlocal = peer % comm->nvsize;
   void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
   // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);

@@ -2453,11 +2500,17 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
     NVTE_CHECK_CUDA(
         cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
   }
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
 }

 void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
                           const size_t recv_offset, const size_t bytes, communicator *comm,
                           const int send_peer, const int recv_peer, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
   int send_peerlocal = send_peer % comm->nvsize;
   int recv_peerlocal = recv_peer % comm->nvsize;

@@ -2507,12 +2560,18 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
                         reinterpret_cast<void *>(&arg15)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv),
                                       kernelArgs));
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
 }

 void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
                                  const size_t send_offset, const size_t recv_offset,
                                  const size_t bytes, communicator *comm, const int send_peer,
                                  const int recv_peer, void *counters, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
   assert(comm->push && comm->use_ce == 0);
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);

@@ -2564,6 +2623,9 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
                         reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(
       &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
 }

 void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,

@@ -2571,6 +2633,9 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
                                       const size_t bytes, communicator *comm, const int send_peer,
                                       const int recv_peer, const int nchunks, void *counters,
                                       bool shuffle, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
   assert(comm->push && comm->use_ce == 0);  // CE is not supported

@@ -2610,11 +2675,17 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
                         reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(
       &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
 }

 void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
                       const size_t dstoffset, const size_t bytes, communicator *comm,
                       const int peer, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
   int peerlocal = peer % comm->nvsize;
   void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0);
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);

@@ -2648,6 +2719,9 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
                             GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
                                 : nullptr));
   }
+#ifdef __HIP_PLATFORM_AMD__
+  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
+#endif
 }

 // producer

@@ -2846,7 +2920,11 @@ __global__ void __launch_bounds__(MAX_THREADS / 4)
 }

 void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
+#ifdef __HIP_PLATFORM_AMD__
+  constexpr int nvec = 8;
+#else
   constexpr int nvec = 32;
+#endif
   assert(input_size % nvec == 0);
   const int num_aligned_elements_per_input = input_size / nvec;
   const int tot_input_size = input_size * num_inputs;
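
Every host entry point in this file (userbuffers_send, userbuffers_recv, and the sendrecv variants) is now bracketed on HIP by a full cudaStreamSynchronize before and after the kernel launch, serializing the P2P copy/flag kernel against all other work on the stream. A minimal standalone sketch of that bracketing (the kernel and flag here are stand-ins, not the library's):

#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for the library's push-send kernel: bump a peer-visible flag.
__global__ void push_flag(int *flag) { *flag = *flag + 1; }

int main() {
  int *flag = nullptr;
  cudaMallocManaged(&flag, sizeof(int));
  *flag = 0;

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // The pattern added for the HIP/DCU build:
  cudaStreamSynchronize(stream);         // drain prior work (e.g. the producing GEMM)
  push_flag<<<1, 1, 0, stream>>>(flag);  // launch the P2P signalling kernel
  cudaStreamSynchronize(stream);         // ensure the flag write landed before returning

  printf("flag = %d\n", *flag);          // prints 1
  cudaStreamDestroy(stream);
  cudaFree(flag);
  return 0;
}

This presumably trades overlap for correctness on the DCU path, consistent with the commit being labelled a temporary fix.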

transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h

@@ -15,7 +15,11 @@
 #include "common/comm_gemm_overlap/userbuffers/userbuffers.h"

+#ifdef __HIP_PLATFORM_AMD__
+#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
+#else
 #define NVTE_COMM_OVERLAP_MAX_STREAMS 3
+#endif

 namespace transformer_engine {
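
NVTE_COMM_OVERLAP_MAX_STREAMS acts as an upper bound on the compute streams used by the comm+GEMM overlap, so dropping it to 1 on HIP caps any larger num_max_streams request. A rough sketch of that clamping (hypothetical helper; the assumption is that the overlap core clamps the caller's request against the macro):

#include <algorithm>
#include <cstdio>

// Mirror of the header's new definition (values copied from the diff).
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif

// Hypothetical helper: how many compute streams are actually created
// when a caller asks for `requested`.
static int effective_streams(int requested) {
  return std::min(requested, NVTE_COMM_OVERLAP_MAX_STREAMS);
}

int main() {
  printf("requested 3 -> %d streams\n", effective_streams(3));  // 1 on HIP, 3 on CUDA
  return 0;
}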

transformer_engine/pytorch/module/base.py

@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
 _multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 3
+_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []

@@ -357,7 +357,7 @@ def initialize_ub(
             helper,  # Helper for torch.distributed callbacks during bootstrapping
             tp_size,  # Tensor-parallel group size (may be different than local_size)
             num_splits=num_splits,
-            num_max_streams=_NUM_MAX_UB_STREAMS,
+            num_max_streams=_NUM_MAX_UB_STREAMS - 1 if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS,
             comm_cga_size=cga_size,
             num_comm_sm=num_sm,
             set_sm_margin=set_sm_margin,