Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
75e9ef24
Commit
75e9ef24
authored
May 27, 2025
by
yuguo
Browse files
Merge branch 'develop_v2.3' of
http://10.16.6.30/dcutoolkit/deeplearing/TransformerEngine
parents
5753c5bb
291fcf52
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
117 additions
and
19 deletions
+117
-19
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+91
-7
transformer_engine/common/gemm/cublaslt_gemm.cu
transformer_engine/common/gemm/cublaslt_gemm.cu
+9
-9
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
...ine/common/include/transformer_engine/comm_gemm_overlap.h
+2
-0
transformer_engine/pytorch/module/base.py
transformer_engine/pytorch/module/base.py
+15
-3
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
75e9ef24
...
@@ -45,6 +45,21 @@ bool ubuf_built_with_mpi() {
...
@@ -45,6 +45,21 @@ bool ubuf_built_with_mpi() {
#endif
#endif
}
}
static
inline
int
getIntEnv
(
const
char
*
name
,
int
defval
,
int
minval
)
{
int
val
=
defval
;
const
char
*
env
=
std
::
getenv
(
name
);
if
(
env
!=
nullptr
&&
env
[
0
]
!=
'\0'
)
{
val
=
atoi
(
env
);
if
(
val
<
minval
)
{
val
=
minval
;
}
}
return
val
;
}
CommOverlapCore
::
CommOverlapCore
(
int
myrank
,
int
numranks
,
int
mylocal
,
int
numlocal
,
int
mynode
,
CommOverlapCore
::
CommOverlapCore
(
int
myrank
,
int
numranks
,
int
mylocal
,
int
numlocal
,
int
mynode
,
int
numnodes
,
int
tp_size
,
ExtAllgatherOp
allgather_handle
,
int
numnodes
,
int
tp_size
,
ExtAllgatherOp
allgather_handle
,
ExtBarrierOp
barrier_handle
,
int
num_splits
,
int
num_max_streams
,
ExtBarrierOp
barrier_handle
,
int
num_splits
,
int
num_max_streams
,
...
@@ -74,10 +89,41 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
...
@@ -74,10 +89,41 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
_gemm_priority
=
gemm_priority
;
_gemm_priority
=
gemm_priority
;
_comm_priority
=
comm_priority
;
_comm_priority
=
comm_priority
;
}
}
int
comm_cu_nums
=
getIntEnv
(
"NVTE_UB_COMM_CU_NUMS"
,
8
,
4
);
unsigned
int
cuMask
[
4
];
unsigned
int
cuMaskSize
=
4
;
if
(
comm_cu_nums
==
4
)
{
cuMask
[
0
]
=
0xfffffff0
;
cuMask
[
1
]
=
0xffffffff
;
cuMask
[
2
]
=
0xffffffff
;
cuMask
[
3
]
=
0xffffffff
;
}
else
if
(
comm_cu_nums
==
8
)
{
cuMask
[
0
]
=
0xffffff00
;
cuMask
[
1
]
=
0xffffffff
;
cuMask
[
2
]
=
0xffffffff
;
cuMask
[
3
]
=
0xffffffff
;
}
else
if
(
comm_cu_nums
==
16
)
{
cuMask
[
0
]
=
0xffff0000
;
cuMask
[
1
]
=
0xffffffff
;
cuMask
[
2
]
=
0xffffffff
;
cuMask
[
3
]
=
0xffffffff
;
}
else
if
(
comm_cu_nums
==
32
)
{
cuMask
[
0
]
=
0x00000000
;
cuMask
[
1
]
=
0xffffffff
;
cuMask
[
2
]
=
0xffffffff
;
cuMask
[
3
]
=
0xffffffff
;
}
else
{
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
}
static
cudaStream_t
compute_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
static
cudaStream_t
compute_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
num_splits
);
i
++
)
{
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
num_splits
);
i
++
)
{
if
(
compute_streams
[
i
]
==
nullptr
)
{
if
(
compute_streams
[
i
]
==
nullptr
)
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
_gemm_priority
));
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
_gemm_priority
));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA
(
hipExtStreamCreateWithCUMask
(
&
compute_streams
[
i
],
cuMaskSize
,
cuMask
));
#endif
}
}
_stream_compute
.
push_back
(
compute_streams
[
i
]);
_stream_compute
.
push_back
(
compute_streams
[
i
]);
}
}
...
@@ -268,6 +314,10 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
...
@@ -268,6 +314,10 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
allgather_handle
,
barrier_handle
,
num_splits
,
num_max_streams
,
comm_cga_size
,
allgather_handle
,
barrier_handle
,
num_splits
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
false
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
false
,
atomic_gemm
)
{
atomic_gemm
)
{
const
char
*
NVTE_BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_BLAS_MULSTREAM"
);
if
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
){
_ub_force_blas_multistream
=
true
;
}
_ub_stream_nums
=
num_max_streams
;
_ub_stream_nums
=
num_max_streams
;
_rs_overlap_first_gemm
=
rs_overlap_first_gemm
;
_rs_overlap_first_gemm
=
rs_overlap_first_gemm
;
_rs_kernel_type
=
getenv
<
int
>
(
"NVTE_RS_STRIDED_ATOMIC"
,
0
);
_rs_kernel_type
=
getenv
<
int
>
(
"NVTE_RS_STRIDED_ATOMIC"
,
0
);
...
@@ -282,10 +332,41 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
...
@@ -282,10 +332,41 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
if
(
_ub_comm
->
myrank
==
0
)
printf
(
"!!! [UB] Register UBuf %d
\n
"
,
_ub_reg
);
if
(
_ub_comm
->
myrank
==
0
)
printf
(
"!!! [UB] Register UBuf %d
\n
"
,
_ub_reg
);
_ubuf
=
TensorWrapper
(
buffer_ptr
,
buffer_shape
,
buffer_dtype
);
_ubuf
=
TensorWrapper
(
buffer_ptr
,
buffer_shape
,
buffer_dtype
);
int
comm_cu_nums
=
getIntEnv
(
"NVTE_UB_COMM_CU_NUMS"
,
8
,
4
);
unsigned
int
cuMask
[
4
];
unsigned
int
cuMaskSize
=
4
;
if
(
comm_cu_nums
==
4
)
{
cuMask
[
0
]
=
0x0000000f
;
cuMask
[
1
]
=
0x00000000
;
cuMask
[
2
]
=
0x00000000
;
cuMask
[
3
]
=
0x00000000
;
}
else
if
(
comm_cu_nums
==
8
)
{
cuMask
[
0
]
=
0x000000ff
;
cuMask
[
1
]
=
0x00000000
;
cuMask
[
2
]
=
0x00000000
;
cuMask
[
3
]
=
0x00000000
;
}
else
if
(
comm_cu_nums
==
16
)
{
cuMask
[
0
]
=
0x0000ffff
;
cuMask
[
1
]
=
0x00000000
;
cuMask
[
2
]
=
0x00000000
;
cuMask
[
3
]
=
0x00000000
;
}
else
if
(
comm_cu_nums
==
32
)
{
cuMask
[
0
]
=
0xffffffff
;
cuMask
[
1
]
=
0x00000000
;
cuMask
[
2
]
=
0x00000000
;
cuMask
[
3
]
=
0x00000000
;
}
else
{
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
}
static
cudaStream_t
comm_stream
;
static
cudaStream_t
comm_stream
;
if
(
comm_stream
==
nullptr
)
{
if
(
comm_stream
==
nullptr
)
{
NVTE_CHECK_CUDA
(
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
comm_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
cudaStreamCreateWithPriority
(
&
comm_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA
(
hipExtStreamCreateWithCUMask
(
&
comm_stream
,
cuMaskSize
,
cuMask
));
#endif
}
}
_stream_comm
=
comm_stream
;
_stream_comm
=
comm_stream
;
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_start_d2dcopy
,
0
));
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_start_d2dcopy
,
0
));
...
@@ -499,7 +580,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
...
@@ -499,7 +580,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto
bias_chunk
=
maybe_get_bias_chunk
(
0
);
auto
bias_chunk
=
maybe_get_bias_chunk
(
0
);
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
0
,
{
workspace_size_chunk
});
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
0
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]);
use_split_accumulator
,
_math_sms
,
_stream_compute
[
0
]);
...
@@ -516,7 +597,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
...
@@ -516,7 +597,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
workspace_chunk
=
get_tensor_chunk
(
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
accumulate
,
use_split_accumulator
,
_math_sms
,
...
@@ -572,7 +653,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
...
@@ -572,7 +653,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
auto
workspace_chunk
=
get_tensor_chunk
(
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
nvte_cublas_gemm
(
input_a_chunk
.
data
(),
B
.
data
(),
output_chunk
.
data
(),
bias_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
accumulate
,
use_split_accumulator
,
_math_sms
,
...
@@ -631,7 +712,10 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
...
@@ -631,7 +712,10 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
allgather_handle
,
barrier_handle
,
tp_size
,
num_max_streams
,
comm_cga_size
,
allgather_handle
,
barrier_handle
,
tp_size
,
num_max_streams
,
comm_cga_size
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
use_ce
,
gemm_priority
,
comm_priority
,
num_comm_sm
,
set_sm_margin
,
use_ce
,
atomic_gemm
)
{
atomic_gemm
)
{
const
char
*
NVTE_BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_BLAS_MULSTREAM"
);
if
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
){
_ub_force_blas_multistream
=
true
;
}
_ub_stream_nums
=
num_max_streams
;
_ub_stream_nums
=
num_max_streams
;
_is_p2p
=
true
;
_is_p2p
=
true
;
_is_reduce_scatter
=
comm_type
==
CommOverlapType
::
RS
;
_is_reduce_scatter
=
comm_type
==
CommOverlapType
::
RS
;
...
@@ -902,7 +986,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
...
@@ -902,7 +986,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
get_tensor_chunk
(
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
use_split_accumulator
,
_math_sms
,
...
@@ -962,7 +1046,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
...
@@ -962,7 +1046,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
get_tensor_chunk
(
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
workspace
,
(
i
%
_stream_compute
.
size
())
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
aux_chunk
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
use_split_accumulator
,
_math_sms
,
...
@@ -1115,7 +1199,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
...
@@ -1115,7 +1199,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
auto
workspace_chunk
=
auto
workspace_chunk
=
get_tensor_chunk
(
workspace
,
stream_id
*
workspace_size_chunk
,
{
workspace_size_chunk
});
get_tensor_chunk
(
workspace
,
stream_id
*
workspace_size_chunk
,
{
workspace_size_chunk
});
if
(
_ub_stream_nums
==
1
)
{
if
(
_ub_stream_nums
==
1
||
_ub_force_blas_multistream
==
1
)
{
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
nvte_cublas_gemm
(
A
.
data
(),
input_b_chunk
.
data
(),
output_chunk
.
data
(),
bias
.
data
(),
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
pre_gelu_out
.
data
(),
transa
,
transb
,
grad
,
workspace_chunk
.
data
(),
accumulate
,
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
]);
use_split_accumulator
,
_math_sms
,
_stream_compute
[
stream_id
]);
...
...
transformer_engine/common/gemm/cublaslt_gemm.cu
View file @
75e9ef24
...
@@ -784,17 +784,17 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
...
@@ -784,17 +784,17 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT
for
(
int
s
=
0
;
s
<
num_stream_used
;
s
++
)
{
for
(
int
s
=
0
;
s
<
num_stream_used
;
s
++
)
{
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
compute_streams
[
s
],
cublas_event
[
0
]));
NVTE_CHECK_CUDA
(
cudaStreamWaitEvent
(
compute_streams
[
s
],
cublas_event
[
0
]));
}
}
const
char
*
NVTE_
HIP
BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_
HIP
BLAS_MULSTREAM"
);
const
char
*
NVTE_BLAS_MULSTREAM
=
std
::
getenv
(
"NVTE_FORCE_BLAS_MULSTREAM"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
const
char
*
NVTE_FORCE_ROCM_GEMM
=
std
::
getenv
(
"NVTE_FORCE_ROCM_GEMM"
);
bool
NVTE_FORCE_
HIP
BLAS_MULSTREAM
;
bool
NVTE_FORCE_BLAS_MULSTREAM
;
if
(
NVTE_
HIP
BLAS_MULSTREAM
!=
nullptr
&&
NVTE_
HIP
BLAS_MULSTREAM
[
0
]
==
'1'
){
if
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
){
NVTE_FORCE_
HIP
BLAS_MULSTREAM
=
true
;
NVTE_FORCE_BLAS_MULSTREAM
=
true
;
if
((
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
&&
(
NVTE_
HIP
BLAS_MULSTREAM
!=
nullptr
&&
NVTE_
HIP
BLAS_MULSTREAM
[
0
]
==
'1'
))
if
((
NVTE_FORCE_ROCM_GEMM
!=
nullptr
&&
NVTE_FORCE_ROCM_GEMM
[
0
]
==
'1'
)
&&
(
NVTE_BLAS_MULSTREAM
!=
nullptr
&&
NVTE_BLAS_MULSTREAM
[
0
]
==
'1'
))
NVTE_ERROR
(
"NVTE_FORCE_
HIP
BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time."
);
NVTE_ERROR
(
"NVTE_FORCE_BLAS_MULSTREAM and NVTE_FORCE_ROCM_GEMM can't be set at the same time."
);
}
else
{
}
else
{
NVTE_FORCE_
HIP
BLAS_MULSTREAM
=
false
;
NVTE_FORCE_BLAS_MULSTREAM
=
false
;
}
}
if
(
NVTE_FORCE_
HIP
BLAS_MULSTREAM
){
if
(
NVTE_FORCE_BLAS_MULSTREAM
){
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
nvte_cublas_gemm
(
A
[
i
],
B
[
i
],
D
[
i
],
bias
[
i
],
pre_gelu_out
[
i
],
transa
,
transb
,
grad
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
workspace
[
i
%
num_streams
],
accumulate
,
use_split_accumulator
,
math_sm_count
,
...
@@ -838,7 +838,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
...
@@ -838,7 +838,7 @@ void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_multi_stream_cublas_batchgemm
);
NVTE_API_CALL
(
nvte_multi_stream_cublas_batchgemm
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
int
batch_count
=
getIntEnv
(
"NVTE_MOE_BATCHCOUNT"
,
2
,
1
);
;
int
batch_count
=
getIntEnv
(
"NVTE_MOE_BATCHCOUNT"
,
2
,
1
);
// Inits streams and events (once, globally)
// Inits streams and events (once, globally)
std
::
call_once
(
init_flag_batchgemm
,
init_streams_and_events_batchgemm
);
std
::
call_once
(
init_flag_batchgemm
,
init_streams_and_events_batchgemm
);
...
...
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h
View file @
75e9ef24
...
@@ -138,6 +138,7 @@ class CommOverlapCore {
...
@@ -138,6 +138,7 @@ class CommOverlapCore {
class
CommOverlapBase
:
public
CommOverlapCore
{
class
CommOverlapBase
:
public
CommOverlapCore
{
protected:
protected:
int
_ub_stream_nums
;
int
_ub_stream_nums
;
bool
_ub_force_blas_multistream
;
int
_rs_kernel_type
;
int
_rs_kernel_type
;
bool
_rs_overlap_first_gemm
;
bool
_rs_overlap_first_gemm
;
cudaStream_t
_stream_comm
;
cudaStream_t
_stream_comm
;
...
@@ -204,6 +205,7 @@ class CommOverlapBase : public CommOverlapCore {
...
@@ -204,6 +205,7 @@ class CommOverlapBase : public CommOverlapCore {
class
CommOverlapP2PBase
:
public
CommOverlapCore
{
class
CommOverlapP2PBase
:
public
CommOverlapCore
{
protected:
protected:
int
_ub_stream_nums
;
int
_ub_stream_nums
;
bool
_ub_force_blas_multistream
;
bool
_is_reduce_scatter
{
false
};
bool
_is_reduce_scatter
{
false
};
bool
_use_multiatomic_ag
{
false
};
bool
_use_multiatomic_ag
{
false
};
bool
_aggregate
;
bool
_aggregate
;
...
...
transformer_engine/pytorch/module/base.py
View file @
75e9ef24
...
@@ -128,7 +128,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor
...
@@ -128,7 +128,7 @@ def get_dummy_wgrad(shape: list, dtype: torch.dtype, zero=False) -> torch.Tensor
_dummy_wgrads
[(
shape
[
0
],
shape
[
1
],
dtype
)].
fill_
(
0
)
_dummy_wgrads
[(
shape
[
0
],
shape
[
1
],
dtype
)].
fill_
(
0
)
return
_dummy_wgrads
[(
shape
[
0
],
shape
[
1
],
dtype
)].
detach
()
return
_dummy_wgrads
[(
shape
[
0
],
shape
[
1
],
dtype
)].
detach
()
ub_comm_cu_nums
=
int
(
os
.
getenv
(
"NVTE_UB_COMM_CU_NUMS"
,
"8"
))
def
initialize_ub
(
def
initialize_ub
(
shape
:
list
,
shape
:
list
,
tp_size
:
int
,
tp_size
:
int
,
...
@@ -279,12 +279,24 @@ def initialize_ub(
...
@@ -279,12 +279,24 @@ def initialize_ub(
layers_reduce_scatter_overlap
=
[
"proj_fprop"
,
"fc2_fprop"
,
"qkv_wgrad"
,
"fc1_wgrad"
]
layers_reduce_scatter_overlap
=
[
"proj_fprop"
,
"fc2_fprop"
,
"qkv_wgrad"
,
"fc1_wgrad"
]
dgrad_reduce_scatter_overlap
=
[
"qkv_dgrad"
,
"fc1_dgrad"
]
dgrad_reduce_scatter_overlap
=
[
"qkv_dgrad"
,
"fc1_dgrad"
]
# Default overlap methods for layers
# Default overlap methods for layers
if
bool
(
int
(
os
.
getenv
(
"NVTE_NO_PIPELINE_OVERLAP"
,
"0"
))):
if
bool
(
int
(
os
.
getenv
(
"NVTE_
PROJ_NO_PIPELINE_OVERLAP"
,
"0"
)))
and
bool
(
int
(
os
.
getenv
(
"NVTE_FC2_
NO_PIPELINE_OVERLAP"
,
"0"
))):
methods
=
{
methods
=
{
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
,
"proj_fprop"
,
"fc2_fprop"
],
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
,
"proj_fprop"
,
"fc2_fprop"
],
"pipeline"
:
[],
"pipeline"
:
[],
"bulk"
:
[
"qkv_dgrad"
,
"qkv_wgrad"
,
"fc1_dgrad"
,
"fc1_wgrad"
],
"bulk"
:
[
"qkv_dgrad"
,
"qkv_wgrad"
,
"fc1_dgrad"
,
"fc1_wgrad"
],
}
}
elif
bool
(
int
(
os
.
getenv
(
"NVTE_PROJ_NO_PIPELINE_OVERLAP"
,
"0"
))):
methods
=
{
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
,
"proj_fprop"
],
"pipeline"
:
[
"fc2_fprop"
],
"bulk"
:
[
"qkv_dgrad"
,
"qkv_wgrad"
,
"fc1_dgrad"
,
"fc1_wgrad"
],
}
elif
bool
(
int
(
os
.
getenv
(
"NVTE_FC2_NO_PIPELINE_OVERLAP"
,
"0"
))):
methods
=
{
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
,
"fc2_fprop"
],
"pipeline"
:
[
"proj_fprop"
],
"bulk"
:
[
"qkv_dgrad"
,
"qkv_wgrad"
,
"fc1_dgrad"
,
"fc1_wgrad"
],
}
else
:
else
:
methods
=
{
methods
=
{
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
],
"ring_exchange"
:
[
"qkv_fprop"
,
"fc1_fprop"
,
"proj_dgrad"
,
"fc2_dgrad"
],
...
@@ -313,7 +325,7 @@ def initialize_ub(
...
@@ -313,7 +325,7 @@ def initialize_ub(
default_cfg
=
{
default_cfg
=
{
"method"
:
method
,
"method"
:
method
,
"is_reduce_scatter"
:
is_reduce_scatter
,
"is_reduce_scatter"
:
is_reduce_scatter
,
"num_sm"
:
1
if
method
==
"ring_exchange"
else
8
,
"num_sm"
:
1
if
method
==
"ring_exchange"
else
ub_comm_cu_nums
,
"cga_size"
:
1
if
method
==
"ring_exchange"
else
2
,
"cga_size"
:
1
if
method
==
"ring_exchange"
else
2
,
"set_sm_margin"
:
not
method
==
"ring_exchange"
,
"set_sm_margin"
:
not
method
==
"ring_exchange"
,
"num_splits"
:
tp_size
if
method
==
"ring_exchange"
else
4
,
"num_splits"
:
tp_size
if
method
==
"ring_exchange"
else
4
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment