OpenDAS / TransformerEngine · Commits

Commit 07b750a2
Authored Apr 16, 2025 by yuguo
Parent: 8fb50d09

[DCU] tmp fix overlap allmethod

Taken together, the hunks below read as a temporary DCU fix for the comm+GEMM overlap path: the overlap test pins the qkv/fc1 dgrad and proj/fc2 fprop overlaps to the `ring_exchange` method and shortens the benchmark warmup, `split_overlap_ag` gains a `cudaDeviceSynchronize()` after its send-stream copies, and `userbuffers.cu` drops the HIP-only `cudaStreamSynchronize()` wrappers and non-atomic flag increments in favor of `__threadfence_system()` fences and unconditional `atomicAdd_system()` signaling.
3 changed files with 43 additions and 70 deletions:

- tests/pytorch/distributed/run_layer_with_overlap.py (+5 / -3)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (+1 / -0)
- transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu (+37 / -67)
tests/pytorch/distributed/run_layer_with_overlap.py
```diff
@@ -212,7 +212,7 @@ def _parse_args(argv=None, namespace=None):
     parser.add_argument(
         "--benchmark-iter",
         type=int,
-        default=100,
+        default=20,
         help="Number of iterations for benchmarking perf.",
     )
     parser.add_argument(
```
```diff
@@ -376,6 +376,8 @@ def _train(opts):
     ub_cfgs = {
+        "qkv_dgrad": {"method": "ring_exchange"},
+        "fc1_dgrad": {"method": "ring_exchange"},
         "proj_fprop": {"method": "ring_exchange"},
         "fc2_fprop": {"method": "ring_exchange"},
     }
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
```
```diff
@@ -498,11 +500,11 @@ def _train(opts):
     if opts.benchmark:
         # Warmup to not profile CPU overhead
-        for _ in range(100):
+        for _ in range(20):
             if opts.use_cuda_graphs:
                 test_graph.replay()
             else:
-                test_out = run_fwd_bwd(test_model, test_x)
+                test_out = run_fwd_bwd(ref_model, ref_x)
         torch.cuda.cudart().cudaProfilerStart()
         for _ in range(opts.benchmark_iter):
             if opts.use_cuda_graphs:
```
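A note on the profiler bracketing this loop relies on: `torch.cuda.cudart().cudaProfilerStart()` is the Python spelling of CUDA's `cudaProfilerStart()`, so only the iterations between start and stop appear in a profile while the warmup stays invisible. A minimal CUDA sketch of the same warmup-then-measure shape (the kernel, sizes, and counts here are invented for illustration):

```cuda
#include <cuda_profiler_api.h>

__global__ void work(float *x) { x[threadIdx.x] = x[threadIdx.x] * 2.0f + 1.0f; }

int main() {
  float *x;
  cudaMalloc(&x, 256 * sizeof(float));

  // Warmup the profiler never sees, mirroring the reduced range(20) above.
  for (int i = 0; i < 20; ++i) work<<<1, 256>>>(x);

  cudaProfilerStart();  // the profiling window covers only the measured iterations
  for (int i = 0; i < 20; ++i) work<<<1, 256>>>(x);
  cudaProfilerStop();

  cudaDeviceSynchronize();
  cudaFree(x);
  return 0;
}
```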
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
```diff
@@ -880,6 +880,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
                                  _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
                                  cudaMemcpyDeviceToDevice, _stream_send[0]));
     }
+    NVTE_CHECK_CUDA(cudaDeviceSynchronize());
   }
 }
```
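The single added line serializes the entire device after the userbuffer copies are queued on `_stream_send[0]`. A minimal sketch of the pattern under assumed buffer sizes (nothing here is TransformerEngine API beyond the idea of a side stream; `send_stream` is a stand-in for `_stream_send[0]`):

```cuda
#include <cstdio>

int main() {
  const size_t bytes = 1 << 20;
  char *src = nullptr, *dst = nullptr;
  cudaStream_t send_stream;  // stand-in for _stream_send[0]
  cudaMalloc(&src, bytes);
  cudaMalloc(&dst, bytes);
  cudaStreamCreate(&send_stream);

  // Queue an async device-to-device copy on the side stream...
  cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, send_stream);
  // ...then drain every stream on the device before continuing, as the added
  // NVTE_CHECK_CUDA(cudaDeviceSynchronize()) does. Heavyweight but
  // unconditionally correct, which fits a "tmp fix".
  cudaDeviceSynchronize();

  printf("copy complete\n");
  cudaStreamDestroy(send_stream);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
```

A finer-grained alternative would be a `cudaEvent_t` recorded on the send stream and waited on by the consumer stream; the device-wide sync trades that precision for simplicity.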
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu
```diff
@@ -292,11 +292,17 @@ __global__ void __launch_bounds__(MAX_THREADS)
   targetgpu = threadIdx.x * gpustep + firstrank;
   myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
   reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
+  __threadfence_system();
   reduce_id = (*reduceidptr) + 1;
+  __threadfence_system();
   flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
+  __threadfence_system();
   if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
+  __threadfence_system();
   volatile int *flag = (volatile int *)&(myptr[targetgpu]);
+  __threadfence_system();
   userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
+  __threadfence_system();
   clock_t s = clock64();
   while (CHECK_IDS(*flag, reduce_id)) {
     if (CHECK_TIMEOUT(s, ub_timeout)) {
```
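Every addition in this hunk is a `__threadfence_system()` placed between a write that a peer will observe and the flag traffic that signals it. A toy single-GPU sketch of the ordering being enforced (block co-residency is assumed, which is safe for two blocks, and all names are invented):

```cuda
#include <cstdio>

// Block 0 writes a payload, fences, then raises a flag; block 1 spins on the
// flag, fences, then reads the payload. Without the fence between the data
// write and the flag write, an observer could see the flag before the data.
__global__ void handshake(volatile int *flag, int *data) {
  if (blockIdx.x == 0) {
    if (threadIdx.x == 0) {
      *data = 42;              // payload
      __threadfence_system();  // order payload before flag for all observers
      *flag = 1;               // signal
    }
  } else if (threadIdx.x == 0) {
    while (*flag == 0) {}      // spin until signaled
    __threadfence_system();    // order the flag read before the payload read
    printf("consumer saw %d\n", *data);
  }
}

int main() {
  int *flag, *data;
  cudaMalloc(&flag, sizeof(int));
  cudaMalloc(&data, sizeof(int));
  cudaMemset(flag, 0, sizeof(int));
  handshake<<<2, 32>>>(flag, data);
  cudaDeviceSynchronize();
  cudaFree(flag);
  cudaFree(data);
  return 0;
}
```

On one GPU a plain `__threadfence()` would suffice; the kernels above use the system scope because the flags live in buffers that peer GPUs (and potentially the host) poll.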
```diff
@@ -309,7 +315,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
+    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
+    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
```
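The fences here bracket the "last SM" trick these kernels use for inter-block completion: each block atomically adds its share to a counter, and the block whose addition reaches the expected total (`NVTE_MAX_SMS * reduce_id` in the source) knows every other block has finished its writes. A condensed sketch with invented names and `gridDim.x` as the target total:

```cuda
#include <cstdio>

__global__ void last_block_flag(int *counter, int *last_block) {
  __shared__ int lastSM;
  // ... per-block work would happen here ...
  __syncthreads();
  if (threadIdx.x == 0) {
    __threadfence();  // make this block's work visible before the count
    int old_val = atomicAdd(counter, 1);
    lastSM = (old_val + 1 == gridDim.x);  // did our add complete the total?
    if (lastSM) *last_block = blockIdx.x;
  }
}

int main() {
  int *counter, *last_block;
  cudaMalloc(&counter, sizeof(int));
  cudaMalloc(&last_block, sizeof(int));
  cudaMemset(counter, 0, sizeof(int));
  last_block_flag<<<8, 64>>>(counter, last_block);
  int winner;
  cudaMemcpy(&winner, last_block, sizeof(int), cudaMemcpyDeviceToHost);
  printf("last block to finish: %d\n", winner);
  return 0;
}
```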
```diff
@@ -340,8 +348,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     userptr[myrank][mylineoffset + line] = sum;
   }
   __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
+  __threadfence_system();
 }

 // fp16 inplace reduce-scatter kernel
 template <int RANKS>
```
```diff
@@ -359,7 +368,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   targetgpu = threadIdx.x * gpustep + firstrank;
   myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
   reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
+  __threadfence_system();
   reduce_id = (*reduceidptr) + 1;
+  __threadfence_system();
   flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
   __threadfence_system();
   if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
```
```diff
@@ -380,7 +391,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
+    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
+    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
```
```diff
@@ -1237,9 +1250,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
+    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
+    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
+    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
+    __threadfence_system();
     clock_t s = clock64();
   }
```
```diff
@@ -1275,7 +1292,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
+    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
+    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
     else
```
```diff
@@ -1283,9 +1302,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
+    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
+    __threadfence_system();
     flagptr[physgpu] = reduce_id;
+    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
+    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
```
```diff
@@ -1314,9 +1337,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
+    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
+    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
+    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
+    __threadfence_system();
   }
   __syncthreads();
   localptr = userptr[myrank];
```
```diff
@@ -1370,7 +1397,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
+    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
+    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
     else
```
```diff
@@ -1378,9 +1407,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
+    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
+    __threadfence_system();
     flagptr[physgpu] = reduce_id;
+    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
+    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
```
```diff
@@ -2090,11 +2123,7 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
 #endif

 __global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
-#ifdef __HIP_PLATFORM_AMD__
-  *flagptr = *flagptr + 1;
-#else
   atomicAdd_system(flagptr, 1);
-#endif
 }

 __global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); }
```
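This hunk (and the ones below) drop the HIP-only plain increment in favor of `atomicAdd_system()` everywhere. The distinction matters: `*flagptr = *flagptr + 1` is a non-atomic read-modify-write that can lose updates when more than one agent touches the flag, while `atomicAdd_system()` is atomic at system scope, covering all GPUs and the CPU (it requires compute capability 6.0+). A toy demonstration on one GPU, with invented kernels:

```cuda
#include <cstdio>

__global__ void racy_inc(int *p, int iters) {
  for (int i = 0; i < iters; ++i) *p = *p + 1;             // non-atomic RMW
}

__global__ void atomic_inc(int *p, int iters) {
  for (int i = 0; i < iters; ++i) atomicAdd_system(p, 1);  // system-scope atomic
}

int main() {
  int *p, host;
  cudaMalloc(&p, sizeof(int));

  cudaMemset(p, 0, sizeof(int));
  racy_inc<<<2, 1>>>(p, 1 << 20);  // two blocks race on one flag
  cudaMemcpy(&host, p, sizeof(int), cudaMemcpyDeviceToHost);
  printf("racy:   %d (expected %d)\n", host, 2 << 20);  // usually short

  cudaMemset(p, 0, sizeof(int));
  atomic_inc<<<2, 1>>>(p, 1 << 20);
  cudaMemcpy(&host, p, sizeof(int), cudaMemcpyDeviceToHost);
  printf("atomic: %d\n", host);  // always exact
  return 0;
}
```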
```diff
@@ -2165,18 +2194,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
-#ifdef __HIP_PLATFORM_AMD__
-    *flagptr = *flagptr + 1;
-#else
     atomicAdd_system(flagptr, 1);  // otherwise need local SM sync before sending flag
-#endif
   } else {  // 0 bytes and 1 SM only
-#ifdef __HIP_PLATFORM_AMD__
-    *flagptr = *flagptr + 1;
-#else
     atomicAdd_system(flagptr, 1);
-#endif
+    __threadfence_system();
   }
 }
```
```diff
@@ -2188,7 +2210,9 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe
                                       int *ce_start_ptr, int *ce_end_ptr) {
   const int signal_id = (*recv_id) + adder;
   *recv_id = signal_id;
+  __threadfence_system();
   volatile int *flag = (volatile int *)flagptr;
+  __threadfence_system();
   if (*flag >= signal_id) return;
   clock_t s = clock64();
   while (CHECK_IDS(*flag, signal_id)) {
```
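`CHECK_IDS` and `CHECK_TIMEOUT` are macros from this file; the pattern they implement is a spin-wait on a peer-written flag that bails out after a clock budget instead of hanging the GPU. A hedged sketch of that pattern (`TIMEOUT_TICKS`, `wait_for_signal`, and the kernel are invented; the real code derives its budget from `ub_timeout`):

```cuda
#include <cstdio>

#define TIMEOUT_TICKS (1ULL << 32)  // invented budget; the source uses ub_timeout

// Spin until the flag reaches the expected signal id, or give up.
__device__ bool wait_for_signal(volatile int *flag, int signal_id) {
  long long start = clock64();
  while (*flag < signal_id) {  // CHECK_IDS analogue
    if ((unsigned long long)(clock64() - start) > TIMEOUT_TICKS)
      return false;            // CHECK_TIMEOUT analogue
  }
  __threadfence_system();      // order the flag read before later data reads
  return true;
}

__global__ void recv_like(volatile int *flag, int signal_id, int *ok) {
  if (threadIdx.x == 0) *ok = wait_for_signal(flag, signal_id);
}

int main() {
  int *flag, *ok, host_ok, one = 1;
  cudaMalloc(&flag, sizeof(int));
  cudaMalloc(&ok, sizeof(int));
  cudaMemcpy(flag, &one, sizeof(int), cudaMemcpyHostToDevice);  // pre-signal
  recv_like<<<1, 32>>>(flag, 1, ok);
  cudaMemcpy(&host_ok, ok, sizeof(int), cudaMemcpyDeviceToHost);
  printf("signal observed: %d\n", host_ok);
  return 0;
}
```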
```diff
@@ -2235,18 +2259,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
-#ifdef __HIP_PLATFORM_AMD__
-    *send_flagptr = *send_flagptr + 1;
-#else
     atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
-#endif
   } else {  // 0 bytes and 1 SM only
-#ifdef __HIP_PLATFORM_AMD__
-    *send_flagptr = *send_flagptr + 1;
-#else
     atomicAdd_system(send_flagptr, 1);
-#endif
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {
```
```diff
@@ -2301,18 +2317,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
     __syncthreads();
     if (threadIdx.x) return;
     __threadfence_system();
-#ifdef __HIP_PLATFORM_AMD__
-    *send_flagptr = *send_flagptr + 1;
-#else
     atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
-#endif
   } else {  // 0 bytes and 1 SM only
-#ifdef __HIP_PLATFORM_AMD__
-    *send_flagptr = *send_flagptr + 1;
-#else
     atomicAdd_system(send_flagptr, 1);
-#endif
   }
   if (blockIdx.x == 0 && threadIdx.x == 0) {
```
```diff
@@ -2382,19 +2390,11 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
     __syncthreads();
     if (!threadIdx.x) {
       __threadfence_system();
-#ifdef __HIP_PLATFORM_AMD__
-      *send_flagptr = *send_flagptr + 1;
-#else
       atomicAdd_system(send_flagptr, 1);  // otherwise need local SM sync before sending flag
-#endif
     }
   } else {  // 0 bytes and 1 SM only
-#ifdef __HIP_PLATFORM_AMD__
-    *send_flagptr = *send_flagptr + 1;
-#else
     atomicAdd_system(send_flagptr, 1);
-#endif
   }

   // wait for message to arrive.
```
```diff
@@ -2466,9 +2466,6 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
 void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
                       const size_t dstoffset, const size_t bytes, communicator *comm,
                       const int peer, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
   int peerlocal = peer % comm->nvsize;
   void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
   // void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
```
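The deleted guard above (and its twins in the hunks below) wrapped every host-side send/recv entry point in a full `cudaStreamSynchronize()`. Work queued on a single stream already executes in order, so within one stream the sync buys no ordering, only a CPU stall. A minimal sketch, with invented kernels and sizes:

```cuda
#include <cstdio>

__global__ void stage1(float *buf) { buf[threadIdx.x] += 1.0f; }
__global__ void stage2(float *buf) { buf[threadIdx.x] *= 2.0f; }

int main() {
  float *buf;
  cudaStream_t stream;
  cudaMalloc(&buf, 256 * sizeof(float));
  cudaMemset(buf, 0, 256 * sizeof(float));
  cudaStreamCreate(&stream);

  stage1<<<1, 256, 0, stream>>>(buf);
  // cudaStreamSynchronize(stream);   // unnecessary: stream order already suffices
  stage2<<<1, 256, 0, stream>>>(buf);

  cudaStreamSynchronize(stream);      // one sync where the host actually needs results
  float first;
  cudaMemcpy(&first, buf, sizeof(float), cudaMemcpyDeviceToHost);
  printf("buf[0] = %f\n", first);     // (0 + 1) * 2 = 2
  return 0;
}
```

The caveat is cross-stream and cross-device consumers: those orderings come from the in-kernel fences and flags added earlier in this file, not from stream order, which is presumably why the device-side fences and the host-side sync removal land in the same commit.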
```diff
@@ -2500,17 +2497,11 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
     NVTE_CHECK_CUDA(
         cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
   }
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
 }

 void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
                           const size_t recv_offset, const size_t bytes, communicator *comm,
                           const int send_peer, const int recv_peer, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
   int send_peerlocal = send_peer % comm->nvsize;
   int recv_peerlocal = recv_peer % comm->nvsize;
```
```diff
@@ -2560,18 +2551,12 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
       reinterpret_cast<void *>(&arg15)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv),
                                       kernelArgs));
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
 }

 void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
                                  const size_t send_offset, const size_t recv_offset,
                                  const size_t bytes, communicator *comm, const int send_peer,
                                  const int recv_peer, void *counters, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
   assert(comm->push && comm->use_ce == 0);
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
```
```diff
@@ -2623,9 +2608,6 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
                 reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(
       &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
 }

 void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
```
```diff
@@ -2633,9 +2615,6 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
                                       const size_t bytes, communicator *comm, const int send_peer,
                                       const int recv_peer, const int nchunks, void *counters,
                                       bool shuffle, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
   assert(comm->push && comm->use_ce == 0);  // CE is not supported
```
```diff
@@ -2675,17 +2654,11 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
       reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
   NVTE_CHECK_CUDA(cudaLaunchKernelExC(
       &cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
 }

 void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
                       const size_t dstoffset, const size_t bytes, communicator *comm,
                       const int peer, cudaStream_t stream) {
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
   int peerlocal = peer % comm->nvsize;
   void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0);
   bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
```
```diff
@@ -2719,9 +2692,6 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
              GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr));
   }
-#ifdef __HIP_PLATFORM_AMD__
-  NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
-#endif
 }

 // producer
```