Commit dfd264c3 authored by yuguo

[DCU] tmp fix p2p overlap

parent 24b1c0ff
@@ -16,6 +16,7 @@ from functools import partial, reduce
import torch
import torch.distributed as dist
from torch.distributed.elastic.multiprocessing.errors import record
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine.pytorch as te
import transformer_engine.pytorch.cpp_extensions as tex
@@ -311,6 +312,7 @@ def _main(opts):
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
opts.comm_type,
num_max_streams=2 if IS_HIP_EXTENSION else 3,
set_sm_margin=opts.comm_type == tex.CommOverlapType.RS or opts.atomic,
atomic_gemm=opts.atomic,
aggregate=opts.aggregate,
@@ -322,6 +324,7 @@ def _main(opts):
buffer_dtype,
helper,
tp_size, # Tensor-parallel group size (may be different than LOCAL_SIZE)
num_max_streams=1 if IS_HIP_EXTENSION else 3,
atomic_gemm=opts.atomic,
)
)
@@ -398,7 +401,7 @@ def _main(opts):
)
# Allocate cuBLAS workspace
-workspace_size = 3 * get_cublas_workspace_size_bytes()
+workspace_size = 2 * get_cublas_workspace_size_bytes()
workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda")
# Gather global tensors and calculate reference result (need these first for Fp8 scales)
...
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-# mpirun -np 4 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=1024 --batch-size=2 --num-heads=16 --head-dim=48 --comm-type=AG --p2p
+# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=2048 --batch-size=2 --num-heads=96 --head-dim=128 --comm-type=AG --p2p
# mpirun -np 8 --allow-run-as-root --oversubscribe --quiet python3 /home/TransformerEngine/tests/pytorch/distributed/run_gemm_with_overlap.py --check-numerics --seed=42 --seq-length=2048 --batch-size=2 --num-heads=96 --head-dim=128 --comm-type=RS --p2p
import os
import subprocess
from pathlib import Path
@@ -19,10 +20,10 @@ if torch.cuda.device_count() < 2:
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
RNG_SEED: int = 42
-SEQ_LENGTH: int = 1024
+SEQ_LENGTH: int = 2048
BATCH_SIZE: int = 2
-NUM_HEADS: int = 16
+NUM_HEADS: int = 96
-HEAD_DIM: int = 48
+HEAD_DIM: int = 128
TE_LAYERS = [
te.Linear,
te.LayerNormLinear,
...
@@ -430,7 +430,22 @@ struct hip_f8 {
#endif // #ifdef __gfx942__
// convert to hip_bfloat16
-explicit inline HIP_HOST_DEVICE operator __hip_bfloat16() const;
+explicit inline HIP_HOST_DEVICE
operator __hip_bfloat16() const {
if (T == hip_f8_type::bf8) {
if (get_hip_f8_bias_mode()) {
return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<2, 5, float, true/*negative_zero_nan*/>(data));
} else {
return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<2, 5, float, false/*negative_zero_nan*/>(data));
}
} else /* fp8*/ {
if (get_hip_f8_bias_mode()) {
return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<3, 4, float, true/*negative_zero_nan*/>(data));
} else {
return static_cast<__hip_bfloat16>(hip_f8_impl::cast_from_f8<3, 4, float, false/*negative_zero_nan*/>(data));
}
}
}
// check for zero
inline HIP_HOST_DEVICE bool is_zero() const {
...
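The new operator routes through hip_f8_impl::cast_from_f8<wm, we, float, negative_zero_nan> (wm=2, we=5 for bf8/E5M2; wm=3, we=4 for fp8/E4M3) and only then narrows the resulting float to __hip_bfloat16. A minimal standalone decoder in that spirit is sketched below for illustration; it assumes the usual IEEE-style bit layout and deliberately skips NaN/Inf handling and the negative_zero_nan bias mode that the real implementation covers.

```cpp
#include <cmath>
#include <cstdint>

// Illustrative decode of an 8-bit float with `wm` mantissa bits and `we` exponent
// bits into a 32-bit float (wm=2, we=5 matches bf8/E5M2; wm=3, we=4 matches fp8/E4M3).
// NaN/Inf encodings and the negative_zero_nan mode are deliberately left out.
template <int wm, int we>
float fp8_to_float(uint8_t x) {
  const int exp_bias = (1 << (we - 1)) - 1;             // 15 for E5M2, 7 for E4M3
  const uint8_t sign = x >> 7;
  const uint8_t exp_field = (x >> wm) & ((1 << we) - 1);
  const uint8_t man_field = x & ((1 << wm) - 1);
  float mag;
  if (exp_field == 0) {
    // Subnormal: no implicit leading 1, minimum exponent.
    mag = std::ldexp(static_cast<float>(man_field), 1 - exp_bias - wm);
  } else {
    // Normal: implicit leading 1.
    mag = std::ldexp(1.0f + static_cast<float>(man_field) / (1 << wm), exp_field - exp_bias);
  }
  return sign ? -mag : mag;
}
```

The added operator then only performs the final static_cast<__hip_bfloat16> narrowing on that intermediate float.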
@@ -11,6 +11,9 @@
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half_dtype nv_bfloat16
#elif defined(__HIP_PLATFORM_AMD__)
#include <cuda_bf16.h>
#define half_dtype __hip_bfloat16
#else
#include <cuda_fp16.h>
#define half_dtype half
@@ -358,9 +361,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
__threadfence_system();
if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
__threadfence_system();
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
__threadfence_system();
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
__threadfence_system();
clock_t s = clock64();
while (CHECK_IDS(*flag, reduce_id)) {
if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -404,8 +411,9 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
(reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
}
__threadfence_system();
if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
__threadfence_system();
} // fp16 reduce-scatter kernel (out of place)
#if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
@@ -2082,7 +2090,11 @@ template void reducescatter2_userbuff_strided_multiatomic_fp8<__nv_fp8_e5m2>(
#endif
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr, 1);
#endif
}
__global__ void kuserbuffers_inc(int *id) { atomicAdd(id, 1); }
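All of the flag updates touched in this file follow the same hand-off: the sender fences its payload writes and then bumps a flag that the receiver spins on. A minimal sketch of that protocol is shown below; the kernel and variable names are illustrative, not part of the userbuffers API. The CUDA path keeps the system-scope atomic, while the HIP path in this commit uses a plain read-modify-write, presumably relying on a single thread of a single block being the only writer.

```cpp
// Sender side: make prior payload writes visible, then raise the flag.
__global__ void producer_signal(int *flagptr) {
  __threadfence_system();          // order payload writes before the flag update
#ifdef __HIP_PLATFORM_AMD__
  *flagptr = *flagptr + 1;         // single-writer plain increment (HIP path in this commit)
#else
  atomicAdd_system(flagptr, 1);    // system-scope atomic increment (CUDA path)
#endif
}

// Receiver side: spin until the flag reaches the expected id, then fence before
// touching the payload.
__global__ void consumer_wait(volatile int *flag, int expected_id) {
  while (*flag < expected_id) {
  }
  __threadfence_system();
}
```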
@@ -2153,10 +2165,18 @@
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*flagptr = *flagptr + 1;
#else
atomicAdd_system(flagptr, 1);
#endif
}
}
@@ -2215,10 +2235,18 @@
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
@@ -2273,10 +2301,18 @@
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
if (blockIdx.x == 0 && threadIdx.x == 0) {
@@ -2346,11 +2382,19 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
__syncthreads();
if (!threadIdx.x) {
__threadfence_system();
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr,
1); // otherwise need local SM sync before sending flag
#endif
}
} else { // 0 bytes and 1 SM only
#ifdef __HIP_PLATFORM_AMD__
*send_flagptr = *send_flagptr + 1;
#else
atomicAdd_system(send_flagptr, 1);
#endif
}
// wait for message to arrive.
@@ -2422,6 +2466,9 @@ __global__ void __launch_bounds__(MAX_THREADS) kuserbuffers_pushsendrecv_multiat
void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
int peerlocal = peer % comm->nvsize;
void *flagptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 0);
// void *ce_send_start_ptr = GET_SEND_PTR_BY_INDEX(peerlocal, comm, dsthandler, 1);
@@ -2453,11 +2500,17 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size_t send_offset,
const size_t recv_offset, const size_t bytes, communicator *comm,
const int send_peer, const int recv_peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
int send_peerlocal = send_peer % comm->nvsize;
int recv_peerlocal = recv_peer % comm->nvsize;
@@ -2507,12 +2560,18 @@ void userbuffers_sendrecv(const int srchandler, const int dsthandler, const size
reinterpret_cast<void *>(&arg15)};
NVTE_CHECK_CUDA(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
const size_t send_offset, const size_t recv_offset,
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, void *counters, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
assert(comm->push && comm->use_ce == 0);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
@@ -2564,6 +2623,9 @@ void userbuffers_sendrecv_atomic(const int srchandler, const int dsthandler,
reinterpret_cast<void *>(&arg15), reinterpret_cast<void *>(&arg16)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_atomic), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler,
@@ -2571,6 +2633,9 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
const size_t bytes, communicator *comm, const int send_peer,
const int recv_peer, const int nchunks, void *counters,
bool shuffle, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
assert(comm->push && comm->use_ce == 0);
// CE is not supported
@@ -2610,11 +2675,17 @@ void userbuffers_sendrecv_multiatomic(const int srchandler, const int dsthandler
reinterpret_cast<void *>(&arg17), reinterpret_cast<void *>(&arg18)};
NVTE_CHECK_CUDA(cudaLaunchKernelExC(
&cfg, reinterpret_cast<void *>(kuserbuffers_pushsendrecv_multiatomic), kernelArgs));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
int peerlocal = peer % comm->nvsize;
void *flagptr = GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 0);
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
@@ -2648,6 +2719,9 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds
GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2)
: nullptr));
}
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}
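Every host-side wrapper in this hunk (userbuffers_send, userbuffers_sendrecv and its atomic/multiatomic variants, userbuffers_recv) now brackets its kernel launch with cudaStreamSynchronize when built for ROCm. A stripped-down sketch of that pattern is shown below; launch_p2p_kernel is a placeholder rather than a real userbuffers function, and the NVTE_CHECK_CUDA error checking of the real code is omitted for brevity.

```cpp
#include <cuda_runtime.h>

// Bracket an asynchronous P2P launch with full stream synchronization on the ROCm
// path, so the signalling kernel never runs concurrently with other work queued on
// the same stream. On the CUDA path the launch stays fully asynchronous.
void launch_serialized(void (*launch_p2p_kernel)(cudaStream_t), cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  cudaStreamSynchronize(stream);   // drain pending work before the send/recv kernel
#endif
  launch_p2p_kernel(stream);       // enqueue the actual push-send / recv-check kernel
#ifdef __HIP_PLATFORM_AMD__
  cudaStreamSynchronize(stream);   // wait until the flag update has landed
#endif
}
```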
// producer
@@ -2846,7 +2920,11 @@ __global__ void __launch_bounds__(MAX_THREADS / 4)
}
void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
constexpr int nvec = 8;
#else
constexpr int nvec = 32;
#endif
assert(input_size % nvec == 0);
const int num_aligned_elements_per_input = input_size / nvec;
const int tot_input_size = input_size * num_inputs;
...
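The smaller nvec relaxes the divisibility requirement enforced by the assert in reduce_bf16: on ROCm each input only needs to be a multiple of 8 bf16 elements (16 bytes) instead of 32 elements (64 bytes). A small worked example, using the hidden size implied by the test constants above (96 heads x 128 head dim = 12288), is sketched below.

```cpp
#include <cassert>

int main() {
#ifdef __HIP_PLATFORM_AMD__
  constexpr int nvec = 8;    // 8 bf16 values per nvec group = 16 bytes (ROCm path in this commit)
#else
  constexpr int nvec = 32;   // 32 bf16 values per nvec group = 64 bytes (CUDA path)
#endif
  const int input_size = 96 * 128;                 // 12288, hidden size from the tests above
  assert(input_size % nvec == 0);                  // mirrors the assert in reduce_bf16
  const int num_aligned_elements_per_input = input_size / nvec;  // 1536 (ROCm) or 384 (CUDA)
  return num_aligned_elements_per_input > 0 ? 0 : 1;
}
```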
@@ -15,7 +15,11 @@
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#ifdef __HIP_PLATFORM_AMD__
#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
#else
#define NVTE_COMM_OVERLAP_MAX_STREAMS 3
#endif
namespace transformer_engine {
...
@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
-_NUM_MAX_UB_STREAMS = 3
+_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
layers_atomic_ring_exchange = []
@@ -357,7 +357,7 @@ def initialize_ub(
helper, # Helper for torch.distributed callbacks during bootstrapping
tp_size, # Tensor-parallel group size (may be different than local_size)
num_splits=num_splits,
-num_max_streams=_NUM_MAX_UB_STREAMS,
+num_max_streams=_NUM_MAX_UB_STREAMS - 1 if IS_HIP_EXTENSION else _NUM_MAX_UB_STREAMS,
comm_cga_size=cga_size,
num_comm_sm=num_sm,
set_sm_margin=set_sm_margin,
...