OpenDAS / TransformerEngine · Commits

Commit caf2fbf2, authored Apr 25, 2025 by yuguo

[DCU] tp overlap opt

Parent: 0b0a70a5

Showing 5 changed files with 112 additions and 45 deletions (+112, -45)
setup.py                                                                 +1   -0
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp       +11   -5
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu  +70  -25
transformer_engine/pytorch/cpp_extensions/gemm.py                        +5   -5
transformer_engine/pytorch/module/base.py                               +25  -10
setup.py

@@ -4,6 +4,7 @@
 """Installation script."""
 # NVTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc pip3 install . -v
+# VTE_FRAMEWORK=pytorch NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 NVTE_USE_ROCBLAS=1 CMAKE_PREFIX_PATH=/opt/dtk/lib/cmake/amd_comgr/ MPI_HOME=/opt/mpi/ NVTE_UB_WITH_MPI=1 CXX=hipcc PYTHONPATH=/home/TransformerEngine/3rdparty/hipify_torch:$PYTHONPATH python3 setup.py bdist_wheel
 import os
 import sys
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp

@@ -312,6 +312,7 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::bulk_overlap
...
@@ -461,7 +462,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()]);
+                     _stream_compute[i % _stream_compute.size()], 1, 0);
     NVTE_CHECK_CUDA(
         cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
...
@@ -509,7 +510,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
                      accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()]);
+                     _stream_compute[i % _stream_compute.size()], 1, 0);
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
...
@@ -540,6 +541,7 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::split_overlap_rs

 /***************************************************************************************************
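[Note] The recurring change in this file: each overlap path already hands completion back to stream_main through cudaEventRecord/cudaStreamWaitEvent pairs, and the commit appends a full cudaDeviceSynchronize() at the end of each path. It also threads two extra trailing arguments (1, 0) through nvte_cublas_gemm; that signature is specific to this DCU fork and its definition is not shown in this diff. Below is a minimal sketch of the same record/wait/sync ordering, written against PyTorch's CUDA stream API (HIP-backed on ROCm builds) rather than the library's C++ internals; names mirror the members above but the code is illustrative only and needs a GPU to run.

import torch

stream_main = torch.cuda.current_stream()
stream_compute = torch.cuda.Stream()
stop_comm = torch.cuda.Event()

with torch.cuda.stream(stream_compute):
    pass  # chunked GEMM / communication work would be enqueued here

stop_comm.record(stream_compute)   # cudaEventRecord(_stop_comm, _stream_compute[0])
stream_main.wait_event(stop_comm)  # cudaStreamWaitEvent(stream_main, _stop_comm, 0)
torch.cuda.synchronize()           # the newly added cudaDeviceSynchronize()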
...
@@ -775,8 +777,10 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   }
   if (_aggregate) {
     const int num_steps = _tp_size / 2;
+#ifndef __HIP_PLATFORM_AMD__
     input_chunk_size *= 2;
     output_chunk_size *= 2;
+#endif

     // Initial 1X input chunk exchange between neighboring peers
     int send_chunk_id = _tp_id;
...
@@ -817,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()]);
+                       _stream_compute[i % _stream_compute.size()], 1, 0);

       if (i < num_steps - 1) {
         // P2P communication
...
@@ -861,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                        aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
                        use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()]);
+                       _stream_compute[i % _stream_compute.size()], 1, 0);

       if (i < _tp_size - 1) {
         // P2P communication
...
@@ -892,6 +896,7 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapP2PBase::split_overlap_ag

 /*
...
@@ -1005,7 +1010,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
     nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
                      pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                     use_split_accumulator, _math_sms, _stream_compute[stream_id]);
+                     use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0);

     if (i > 0) {
       // P2P communication chunk
...
@@ -1034,6 +1039,7 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
+  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
...
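[Note] On the aggregated all-gather path, the per-step chunk doubling is now CUDA-only: under __HIP_PLATFORM_AMD__ the input/output chunk sizes stay at 1X (num_steps = _tp_size / 2 is unchanged context). A hedged numeric sketch of just the sizing arithmetic, with illustrative sizes that are not from the library:

def chunk_size(total_elems: int, tp_size: int, is_hip: bool) -> int:
    size = total_elems // tp_size  # 1X chunk per rank
    if not is_hip:                 # mirrors the new `#ifndef __HIP_PLATFORM_AMD__`
        size *= 2                  # 2X aggregated exchange, kept on CUDA only
    return size

assert chunk_size(1024, 8, is_hip=False) == 256
assert chunk_size(1024, 8, is_hip=True) == 128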
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

@@ -116,8 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     reduce_id++;
   }
   __syncthreads();
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -201,8 +204,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     reduce_id++;
   }
   __syncthreads();
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -312,8 +318,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -378,8 +387,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -823,7 +835,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -907,8 +923,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), /*numchunks * */ adder);
     if (old_val + adder == NVTE_MAX_SMS * (reduce_id /* + numchunks*/)) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -988,7 +1007,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -1084,7 +1107,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -1181,7 +1208,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
 #pragma unroll
   for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
...
@@ -1236,7 +1267,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     clock_t s = clock64();
   }
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS];
   int skipmy = 0;
...
@@ -1314,7 +1349,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __syncthreads();
   localptr = userptr[myrank];
+#ifdef __HIP_PLATFORM_AMD__
+  int warp = blockIdx.x + (threadIdx.x >> 6);
+#else
   int warp = blockIdx.x + (threadIdx.x >> 5);
+#endif
   int dest[RANKS - 1];
   int skipmy = 0;
 #pragma unroll
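[Note] Every kernel-side hunk above makes the same change: the warp index must be derived from the hardware SIMD width, which is 64 lanes (a wavefront) on AMD DCU/ROCm and 32 lanes (a warp) on NVIDIA, hence `threadIdx.x >> 6` versus `>> 5`. A small numeric sketch of the index math (plain Python, illustrative only; the function name is made up for this note):

def warp_index(block_idx: int, thread_idx: int, warp_size: int) -> int:
    # mirrors `blockIdx.x + (threadIdx.x >> shift)` in the kernels above
    shift = warp_size.bit_length() - 1  # 32 -> 5, 64 -> 6
    return block_idx + (thread_idx >> shift)

assert warp_index(0, 63, 64) == 0  # 64 threads form one wavefront on DCU
assert warp_index(0, 63, 32) == 1  # the same threads span two 32-lane warps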
...
@@ -1719,6 +1758,12 @@ __global__ void __launch_bounds__(MAX_THREADS)
       kernelArgs)); \
   }

+#ifdef __HIP_PLATFORM_AMD__
+#define WARPSIZE 64
+#else
+#define WARPSIZE 32
+#endif
+
 void reducescatter2_userbuff_strided(void *output, const int handler, const int offset,
                                      const int rowelements, const int colelements,
                                      const int strideelements, communicator *comm,
...
@@ -1733,10 +1778,10 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8)
       callranks_rs_oop_stride(16) callranks_rs_oop_stride(32)
 }
...
@@ -1755,10 +1800,10 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4)
       callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16)
           callranks_rs_oop_stride_atomic(32)
...
@@ -1782,10 +1827,10 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c
   assert(comm->sm_arch >= 9);
   if (elements < 128) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8)
       callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32)
 }
...
@@ -1827,10 +1872,10 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

-  SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+  SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
   callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4)
       callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16)
           callranks_rs_oop_stride_multiatomic(32)
...
@@ -1848,18 +1893,18 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
       callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
...
@@ -1895,18 +1940,18 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
       callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
...
@@ -1928,11 +1973,11 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
   if (elements < 64) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
           callranks_rs_oopMC(32)
...
@@ -1941,7 +1986,7 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
           callranks_rs_oop(32)
     }
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
       callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
           callranks_rs_oopMC(32)
...
@@ -1974,15 +2019,15 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   assert(comm->sm_arch >= 9);
   if (elements < 128) return;
   int sms = ar_nvsize == 1 ? 2 : comm->sms;
-  int warps = comm->threads / 32;
+  int warps = comm->threads / WARPSIZE;
   if (warps < ar_nvsize) warps = ar_nvsize;

   if (comm_launch_event) {
-    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
+    SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * WARPSIZE, stream, comm_launch_event);
     callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
         callranks_rs_oop_fp8(32)
   } else {
-    SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
+    SETUP_LAUNCH_CONFIG(sms, warps * WARPSIZE, stream);
     callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
         callranks_rs_oop_fp8(32)
   }
 }
...
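[Note] The host-side launchers previously hard-coded the 32-thread warp; with the new WARPSIZE macro (64 under __HIP_PLATFORM_AMD__, else 32), every `comm->threads / 32` and `warps * 32` becomes WARPSIZE-relative, so each block carries whole wavefronts on DCU. A hedged sketch of the resulting launch geometry (function name and sizes are illustrative, not the library's):

def launch_threads(comm_threads: int, ar_nvsize: int, warp_size: int) -> int:
    warps = comm_threads // warp_size  # was comm->threads / 32
    warps = max(warps, ar_nvsize)      # at least one warp per peer rank
    return warps * warp_size           # was warps * 32

assert launch_threads(1024, 8, 32) == 1024
assert launch_threads(1024, 8, 64) == 1024
assert launch_threads(256, 8, 64) == 512  # clamped up to ar_nvsize wavefronts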
transformer_engine/pytorch/cpp_extensions/gemm.py

@@ -93,11 +93,11 @@ def general_gemm(
     transb = layout[1] == "T"
     # assert quantization_params is None, "FP8 output not supported yet"

-    if ub_type is not None:
-        assert ub is not None, (
-            f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
-            + "a valid `ub` communicator object."
-        )
+    # if ub_type is not None:
+    #     assert ub is not None, (
+    #         f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires"
+    #         + "a valid `ub` communicator object."
+    #     )
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
...
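[Note] Commenting out this assertion lets callers pass a ub_type while ub is None, which the base.py changes below produce when a layer's userbuffer communicator is disabled. A hedged sketch of the relaxed precondition (simplified helper, not the module's code):

def overlap_enabled(ub, ub_type) -> bool:
    # overlap is attempted only when both pieces are present; a ub_type with
    # ub=None is now tolerated instead of raising
    if ub is not None:
        assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
    return ub is not None and ub_type is not None

assert overlap_enabled(None, "AG") is False  # previously an assertion error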
transformer_engine/pytorch/module/base.py

@@ -47,7 +47,7 @@ _multi_stream_cublas_workspace = []
 _multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 1 if IS_HIP_EXTENSION else 3
+_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
...
@@ -92,6 +92,10 @@ def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
         )
     return _multi_stream_cublas_batchgemm_workspace

+if bool(int(os.getenv("NVTE_DISABLE_FC2_DGRAD_OVERLAP", "0"))):
+    remove_ag_gemm_dgrad = ["fc2_dgrad"]
+else:
+    remove_ag_gemm_dgrad = []

 def initialize_ub(
     shape: list,
...
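[Note] Two knobs appear here: the HIP build now allows up to 2 userbuffer streams (up from 1), and NVTE_DISABLE_FC2_DGRAD_OVERLAP removes fc2_dgrad from the overlap set at import time. A runnable sketch of the flag parsing (names copied from the diff; IS_HIP_EXTENSION is pinned here only for illustration):

import os

def env_flag(name: str) -> bool:
    # matches the diff's bool(int(os.getenv(name, "0"))) idiom
    return bool(int(os.getenv(name, "0")))

IS_HIP_EXTENSION = True  # assumption for this sketch
_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
remove_ag_gemm_dgrad = ["fc2_dgrad"] if env_flag("NVTE_DISABLE_FC2_DGRAD_OVERLAP") else []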
@@ -237,6 +241,13 @@ def initialize_ub(
     layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"]
     dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"]
     # Default overlap methods for layers
+    if bool(int(os.getenv("NVTE_NO_PIPELINE_OVERLAP", "0"))):
+        methods = {
+            "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad", "proj_fprop", "fc2_fprop"],
+            "pipeline": [],
+            "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"],
+        }
+    else:
         methods = {
             "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"],
             "pipeline": ["proj_fprop", "fc2_fprop"],
...
@@ -264,7 +275,7 @@ def initialize_ub(
     default_cfg = {
         "method": method,
         "is_reduce_scatter": is_reduce_scatter,
-        "num_sm": 1 if method == "ring_exchange" else 16,
+        "num_sm": 1 if method == "ring_exchange" else 8,
         "cga_size": 1 if method == "ring_exchange" else 2,
         "set_sm_margin": not method == "ring_exchange",
         "num_splits": tp_size if method == "ring_exchange" else 4,
...
@@ -377,6 +388,8 @@ def initialize_ub(
             methods[new_method].append(name)

     for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
+        if name in remove_ag_gemm_dgrad:
+            continue
         ub_cfg = get_default_config(name)
         if ub_cfgs is not None and name in ub_cfgs:
             fp8_buf = (name in layers_all_gather_overlap) or (
...
@@ -390,7 +403,9 @@ def initialize_ub(
 def get_ub(name: str):
     """Get userbuffer communicator corresponding to give key."""
     assert _ub_communicators is not None, "UB manager is not initialized."
-    assert name in _ub_communicators, f"UB for {name} is not registered."
+    # assert name in _ub_communicators, f"UB for {name} is not registered."
+    if name in remove_ag_gemm_dgrad:
+        return None
     return _ub_communicators[name]
...
@@ -841,7 +856,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # Non-FP8 case: bgrad is fused with wgrad for this case.
             if not ctx.fp8:
                 if gather_grad_output:
-                    if not ctx.ub_overlap_ag:
+                    if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
                         grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
                     else:
                         ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
...
@@ -853,7 +868,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             grad_bias = None
             if ctx.use_bias:
                 grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
-            if ctx.ub_overlap_ag:
+            if ctx.ub_overlap_ag and ctx.ub_obj_gradout is not None:
                 # Quantize the gradient if needed
                 if not isinstance(
                     grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)
...
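[Note] The remaining hunks close the loop: NVTE_NO_PIPELINE_OVERLAP reroutes the pipeline layers onto ring_exchange, layers in remove_ag_gemm_dgrad are skipped during communicator setup, get_ub() returns None for them instead of asserting, and the backward pass guards both ub_overlap_ag branches on ub_obj_gradout being present, falling back to a plain all-gather otherwise. A hedged end-to-end sketch, simplified from the diff (the parameters added to get_ub here are for illustration only):

def get_ub(name, _ub_communicators, remove_ag_gemm_dgrad):
    # mirrors the new early-return: disabled layers yield no communicator
    assert _ub_communicators is not None, "UB manager is not initialized."
    if name in remove_ag_gemm_dgrad:
        return None
    return _ub_communicators[name]

ub_obj_gradout = get_ub("fc2_dgrad", {}, ["fc2_dgrad"])
ub_overlap_ag = True
# mirrors `if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:`
use_plain_gather = not ub_overlap_ag or ub_obj_gradout is None
assert use_plain_gather  # overlap disabled -> ordinary all-gather path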