OpenDAS / TransformerEngine

Commit 9e6e1871, authored May 20, 2025 by yuguo

Merge branch 'develop_v2.3' into 'main'

Develop v2.3. See merge request dcutoolkit/deeplearing/TransformerEngine!9

Parents: 9815d228, 460b006c
Showing 6 changed files with 201 additions and 76 deletions:

- tests/pytorch/distributed/run_layer_with_overlap.py (+1, -1)
- tests/pytorch/test_batched_linear.py (+29, -4)
- transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (+76, -25)
- transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (+2, -4)
- transformer_engine/pytorch/module/base.py (+3, -2)
- transformer_engine/pytorch/module/batched_linear.py (+90, -40)

In summary, the merge adds a `delay_wgrad_compute` option to `BatchedLinear`, caches CUDA streams for comm-GEMM overlap so they survive re-initialization, makes the userbuffers stream count configurable via `NVTE_UB_STREAM_NUMS` on ROCm builds, and removes device-wide synchronizations from the overlap paths.
tests/pytorch/distributed/run_layer_with_overlap.py (+1, -1)

```diff
@@ -496,7 +496,7 @@ def _train(opts):
     if opts.benchmark:
         # Warmup to not profile CPU overhead
-        for _ in range(20):
+        for _ in range(opts.benchmark_iter):
             if opts.use_cuda_graphs:
                 test_graph.replay()
             else:
```
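The warmup loop is now sized by `opts.benchmark_iter` instead of a hard-coded 20 iterations. For reference, a minimal sketch of the warm-up-then-time pattern this benchmark follows; the `time_graph_replays` helper and its `warmup`/`iters` parameters are hypothetical illustrations, not part of the test:

```python
import torch

def time_graph_replays(graph: torch.cuda.CUDAGraph, warmup: int, iters: int) -> float:
    # Warm up first so one-time CPU/launch overhead is not profiled,
    # matching the intent of the loop in run_layer_with_overlap.py.
    for _ in range(warmup):
        graph.replay()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    stop = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        graph.replay()
    stop.record()
    torch.cuda.synchronize()
    return start.elapsed_time(stop) / iters  # milliseconds per replay
```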
tests/pytorch/test_batched_linear.py (+29, -4)

```diff
@@ -171,7 +171,7 @@ def reset_global_fp8_state():
     FP8GlobalStateManager.reset()


 def _test_batched_linear_accuracy(
-    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+    block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
 ):
     reset_rng_states()
     if fp8:
@@ -202,9 +202,31 @@ def _test_batched_linear_accuracy(
     )
     loss = out.sum()
     loss.backward()
+    if delay_wgrad_compute:
+        if isinstance(block, BatchedLinear):
+            block.backward_dw()
+        else:
+            for i in range(num_gemms):
+                block[i].backward_dw()
     torch.cuda.synchronize()

     outputs = [out, inp_hidden_states.grad]
     for p in block.parameters():
         if p.requires_grad:
-            if getattr(p, "main_grad", None) is not None:
-                outputs.append(p.main_grad)
-                assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
-            else:
-                outputs.append(p.grad)
+            if isinstance(block, BatchedLinear):
+                if getattr(p, "main_grad", None) is not None:
+                    for j in range(batch_num):
+                        outputs.append(p.main_grad[p.main_grad.shape[0] // batch_num * j :
+                                                   p.main_grad.shape[0] // batch_num * (j + 1)])
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    for j in range(batch_num):
+                        outputs.append(p.grad[p.grad.shape[0] // batch_num * j :
+                                              p.grad.shape[0] // batch_num * (j + 1)])
+            else:
+                if getattr(p, "main_grad", None) is not None:
+                    outputs.append(p.main_grad)
+                    assert p.grad is None  # grad should be None if fuse_wgrad_accumulation is True
+                else:
+                    outputs.append(p.grad)
     return outputs


 @pytest.mark.parametrize("dtype", param_types)
@@ -215,6 +237,7 @@ def _test_batched_linear_accuracy(
 @pytest.mark.parametrize("recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
 def test_batched_linear_accuracy(
     dtype,
     num_gemms,
@@ -224,6 +247,7 @@ def test_batched_linear_accuracy(
     recipe,
     fp8_model_params,
     fuse_wgrad_accumulation,
+    delay_wgrad_compute,
     parallel_mode=None,
 ):
     batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
@@ -250,6 +274,7 @@ def test_batched_linear_accuracy(
         parallel_mode=parallel_mode,
         device="cuda",
         fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+        delay_wgrad_compute=delay_wgrad_compute,
     ).eval()
     sequential_linear = torch.nn.ModuleList(
         [
@@ -281,10 +306,10 @@ def test_batched_linear_accuracy(
                 sequential_linear[i * batch_num + j].weight.main_grad = weight_i.main_grad[
                     weight_i.main_grad.shape[0] // batch_num * j :
                     weight_i.main_grad.shape[0] // batch_num * (j + 1)
                 ].clone()

     outputs_ref = _test_batched_linear_accuracy(
-        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )
     outputs = _test_batched_linear_accuracy(
-        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation
+        batched_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation, delay_wgrad_compute, batch_num
     )

     # Should be bit-wise match
@@ -292,4 +317,4 @@ def test_batched_linear_accuracy(
     torch.testing.assert_close(o, o_ref, rtol=6e-3, atol=6e-3)


 if __name__ == "__main__":
-    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True)
+    test_batched_linear_accuracy(torch.float32, 2, 1, "126m", False, recipe.Float8CurrentScaling(), True, True, True)
```
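The updated test compares a fused `BatchedLinear` against per-GEMM reference layers by slicing each fused gradient into `batch_num` equal row blocks. A standalone sketch of that slicing (the tensor values here are made up):

```python
import torch

batch_num = 2
main_grad = torch.arange(12.0).reshape(6, 2)  # stand-in for p.main_grad, rows stacked per batch
rows_per_batch = main_grad.shape[0] // batch_num

# Same index arithmetic as the test: block j owns rows [rows*j, rows*(j+1)).
chunks = [main_grad[rows_per_batch * j : rows_per_batch * (j + 1)] for j in range(batch_num)]
assert torch.equal(torch.cat(chunks), main_grad)
```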
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp (+76, -25)

```diff
@@ -68,10 +68,12 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
     _gemm_priority = gemm_priority;
     _comm_priority = comm_priority;
   }
+  static cudaStream_t compute_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
   for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
-    cudaStream_t stream;
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority));
-    _stream_compute.push_back(std::move(stream));
+    if (compute_streams[i] == nullptr) {
+      NVTE_CHECK_CUDA(
+          cudaStreamCreateWithPriority(&compute_streams[i], cudaStreamNonBlocking, _gemm_priority));
+    }
+    _stream_compute.push_back(compute_streams[i]);
   }

   _num_splits = num_splits;
```
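The constructor previously created fresh compute streams on every instantiation; they are now cached in a function-local `static` array and reused, so tearing down and rebuilding overlap objects no longer re-creates (or leaks) streams. A rough PyTorch-level analog of the caching idea, with a hypothetical helper name:

```python
import torch

_COMPUTE_STREAM_CACHE: list = [None] * 3  # mirrors NVTE_COMM_OVERLAP_MAX_STREAMS

def get_cached_stream(i: int) -> torch.cuda.Stream:
    # Create stream i exactly once; later callers get the same object back,
    # just as the C++ code now reuses its static compute_streams[i].
    if _COMPUTE_STREAM_CACHE[i] is None:
        _COMPUTE_STREAM_CACHE[i] = torch.cuda.Stream()
    return _COMPUTE_STREAM_CACHE[i]
```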
```diff
@@ -225,6 +227,7 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
                       allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false,
                       atomic_gemm) {
+  _ub_stream_nums = num_max_streams;
   _rs_overlap_first_gemm = rs_overlap_first_gemm;
   _rs_kernel_type = getenv<int>("NVTE_RS_STRIDED_ATOMIC", 0);
   NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3,
@@ -238,8 +241,12 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
   if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg);
   _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype);

-  NVTE_CHECK_CUDA(
-      cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority));
+  static cudaStream_t comm_stream;
+  if (comm_stream == nullptr) {
+    NVTE_CHECK_CUDA(
+        cudaStreamCreateWithPriority(&comm_stream, cudaStreamNonBlocking, _comm_priority));
+  }
+  _stream_comm = comm_stream;
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0));
 }
@@ -307,7 +314,6 @@ void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const Te
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0]));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::bulk_overlap
```
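`bulk_overlap` (and the split-overlap paths below) drop their trailing `cudaDeviceSynchronize()`: ordering is already guaranteed by recording a stop event on the worker stream and making `stream_main` wait on it, so a device-wide barrier only adds latency. The same event-based ordering expressed in PyTorch terms, as an illustrative sketch:

```python
import torch

comm_stream = torch.cuda.Stream()
done = torch.cuda.Event()

with torch.cuda.stream(comm_stream):
    pass  # ... communication kernels would be enqueued here ...
done.record(comm_stream)

# The main stream waits only on the comm stream's event; no
# torch.cuda.synchronize() (the analog of cudaDeviceSynchronize) is needed.
torch.cuda.current_stream().wait_event(done)
```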
```diff
@@ -444,9 +450,15 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk});
   auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk});
-  nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                   pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                   use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
+  if (_ub_stream_nums == 1) {
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                     use_split_accumulator, _math_sms, _stream_compute[0]);
+  } else {
+    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                     use_split_accumulator, _math_sms, _stream_compute[0], 1, 0, 0);
+  }
   for (int i = 1; i < _num_splits; i++) {
     input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k});
@@ -454,10 +466,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     workspace_chunk = get_tensor_chunk(
         workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
-                     accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+    if (_ub_stream_nums == 1) {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                       accumulate, use_split_accumulator, _math_sms,
+                       _stream_compute[i % _stream_compute.size()]);
+    } else {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                       accumulate, use_split_accumulator, _math_sms,
+                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+    }
     NVTE_CHECK_CUDA(
         cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()]));
@@ -502,10 +521,17 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
     auto workspace_chunk = get_tensor_chunk(
         workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-    nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
-                     pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
-                     accumulate, use_split_accumulator, _math_sms,
-                     _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+    if (_ub_stream_nums == 1) {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                       accumulate, use_split_accumulator, _math_sms,
+                       _stream_compute[i % _stream_compute.size()]);
+    } else {
+      nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(),
+                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(),
+                       accumulate, use_split_accumulator, _math_sms,
+                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+    }
     NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[i % _stream_compute.size()]));
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0));
@@ -536,7 +562,6 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapBase::split_overlap_rs

 /***************************************************************************************************
@@ -555,6 +580,8 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
                       allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size,
                       gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce,
                       atomic_gemm) {
+  _ub_stream_nums = num_max_streams;
   _is_p2p = true;
   _is_reduce_scatter = comm_type == CommOverlapType::RS;
   _aggregate = aggregate;
@@ -603,13 +630,19 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector<size_t> &buffer_shape,
     NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t)));
   }
+  static cudaStream_t send_streams[NVTE_COMM_OVERLAP_MAX_STREAMS];
+  static cudaStream_t recv_stream;
   for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) {
-    cudaStream_t stream;
-    NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority));
-    _stream_send.push_back(std::move(stream));
+    if (send_streams[i] == nullptr) {
+      NVTE_CHECK_CUDA(
+          cudaStreamCreateWithPriority(&send_streams[i], cudaStreamNonBlocking, _comm_priority));
+    }
+    _stream_send.push_back(send_streams[i]);
   }
-  NVTE_CHECK_CUDA(
-      cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority));
+  if (recv_stream == nullptr) {
+    NVTE_CHECK_CUDA(
+        cudaStreamCreateWithPriority(&recv_stream, cudaStreamNonBlocking, _comm_priority));
+  }
+  _stream_recv = recv_stream;
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0));
 }
@@ -813,10 +846,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }
       if (i < num_steps - 1) {
         // P2P communication
@@ -857,10 +897,17 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
       auto workspace_chunk = get_tensor_chunk(
           workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk});
-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms,
-                       _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms,
+                         _stream_compute[i % _stream_compute.size()], 1, 0, i % _stream_compute.size());
+      }
       if (i < _tp_size - 1) {
         // P2P communication
@@ -891,7 +938,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa,
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0));
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());
 }  // CommOverlapP2PBase::split_overlap_ag

 /*
@@ -1003,9 +1049,15 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
       auto workspace_chunk =
          get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk});
-      nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
-                       pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
-                       use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
+      if (_ub_stream_nums == 1) {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms, _stream_compute[stream_id]);
+      } else {
+        nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(),
+                         pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate,
+                         use_split_accumulator, _math_sms, _stream_compute[stream_id], 1, 0, stream_id);
+      }
       if (i > 0) {
         // P2P communication chunk
@@ -1034,7 +1086,6 @@ void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa,
   }
   NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0));
-  NVTE_CHECK_CUDA(cudaDeviceSynchronize());

   // Reduce GEMM output chunks
   char *reduce_buf_ptr = reinterpret_cast<char *>(_ubufs[_tp_size - 1].dptr());
```
transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h (+2, -4)

```diff
@@ -15,11 +15,7 @@
 #include "common/comm_gemm_overlap/userbuffers/userbuffers.h"

-#ifdef __HIP_PLATFORM_AMD__
-#define NVTE_COMM_OVERLAP_MAX_STREAMS 1
-#else
 #define NVTE_COMM_OVERLAP_MAX_STREAMS 3
-#endif

 namespace transformer_engine {
@@ -141,6 +137,7 @@ class CommOverlapCore {
 class CommOverlapBase : public CommOverlapCore {
  protected:
+  int _ub_stream_nums;
   int _rs_kernel_type;
   bool _rs_overlap_first_gemm;
   cudaStream_t _stream_comm;
@@ -206,6 +203,7 @@ class CommOverlapBase : public CommOverlapCore {
 class CommOverlapP2PBase : public CommOverlapCore {
  protected:
+  int _ub_stream_nums;
   bool _is_reduce_scatter{false};
   bool _use_multiatomic_ag{false};
   bool _aggregate;
```
transformer_engine/pytorch/module/base.py (+3, -2)

```diff
@@ -52,10 +52,11 @@ _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _multi_stream_cublas_workspace = []
 _dummy_wgrads = {}
-multi_stream_cublas_batchgemm_workspace = []
+_multi_stream_cublas_batchgemm_workspace = []
 _cublas_workspace = None
 _ub_communicators = None
-_NUM_MAX_UB_STREAMS = 2 if IS_HIP_EXTENSION else 3
+ub_stream_nums = int(os.getenv("NVTE_UB_STREAM_NUMS", "2"))
+_NUM_MAX_UB_STREAMS = ub_stream_nums if IS_HIP_EXTENSION else 3
 _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None
 layers_atomic_ring_exchange = []
```
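With this change, HIP/ROCm builds read the userbuffers stream count from `NVTE_UB_STREAM_NUMS` (default 2) instead of a hard-coded 2, while CUDA builds keep 3. Because the variable is read at module import time, it has to be set beforehand, e.g.:

```python
import os

# Must be set before transformer_engine.pytorch is first imported;
# per the diff it only takes effect on HIP builds (IS_HIP_EXTENSION).
os.environ["NVTE_UB_STREAM_NUMS"] = "4"

import transformer_engine.pytorch  # noqa: E402  (import after env setup is intentional)
```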
transformer_engine/pytorch/module/batched_linear.py (+90, -40)

```diff
@@ -6,7 +6,7 @@
 import os
 import logging
 from typing import Any, Callable, Dict, Optional, Tuple, Union, List
+import functools

 import torch
 import transformer_engine_torch as tex
@@ -18,6 +18,7 @@ from .base import (
     _2X_ACC_DGRAD,
     _2X_ACC_WGRAD,
 )
+from ._common import WeightGradStore
 from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager
 from ..utils import (
     divide,
@@ -82,6 +83,7 @@ class _BatchLinear(torch.autograd.Function):
         is_first_microbatch: Union[bool, None],
         fp8: bool,
         fp8_calibration: bool,
+        wgrad_store: WeightGradStore,
         fp8_meta: Dict[str, Any],
         fuse_wgrad_accumulation: bool,
         cpu_offloading: bool,
@@ -183,6 +185,7 @@ class _BatchLinear(torch.autograd.Function):
         ctx.tp_size = tp_size
         ctx.requires_dgrad = inp.requires_grad
         ctx.reduce_and_update_bwd_fp8_tensors = False
+        ctx.wgrad_store = wgrad_store

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -246,53 +249,69 @@ class _BatchLinear(torch.autograd.Function):
                 torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
                 for w in weights
             ]

             # WGRAD
-            _, grad_biases, _ = batchgemm(
-                inputmats,
-                grad_output_mats,
-                wgrad_list,
-                ctx.activation_dtype,
-                get_multi_stream_cublas_batchgemm_workspace(),
+            batched_gemm_wgrad = functools.partial(
+                batchgemm,
+                dtype=ctx.activation_dtype,
+                workspaces=get_multi_stream_cublas_batchgemm_workspace(),
                 layout="NT",
                 grad=True,
                 use_bias=ctx.use_bias,
                 accumulate=accumulate_wgrad_into_param_main_grad,
             )

-            # Deallocate input tensor
-            clear_tensor_data(*inputmats)
-            clear_tensor_data(*inputmats_t)
+            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+                ctx.wgrad_store.put([inputmats, grad_output_mats, wgrad_list], batched_gemm_wgrad)
+            else:
+                _, grad_biases_, _ = batched_gemm_wgrad(inputmats, grad_output_mats, wgrad_list)
+                for i in range(ctx.num_gemms):
+                    if grad_biases[i] is None:
+                        grad_biases[i] = grad_biases_[i]
+                del grad_biases_
+
+                # Deallocate input tensor
+                clear_tensor_data(*inputmats)
+                clear_tensor_data(*inputmats_t)

             def handle_custom_ddp_from_mcore(w, wgrad):
                 if w.requires_grad:
                     if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
                         w.grad_added_to_main_grad = True
                         if getattr(w, "zero_out_wgrad", False):
                             wgrad = torch.zeros(
                                 w.main_grad.shape,
                                 dtype=w.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
                         else:
                             wgrad = torch.empty(
                                 w.main_grad.shape,
                                 dtype=w.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False,
                             )
                 elif ctx.fuse_wgrad_accumulation:
                     wgrad = None
                 else:
                     wgrad = None
                 return wgrad

             wgrad_list = [
                 handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
             ]
         else:
             wgrad_list = [None] * ctx.num_gemms

-        if not ctx.use_bias:
+        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
+            wgrad_list = [None] * ctx.num_gemms
+
+        if not ctx.use_bias or (
+            ctx.wgrad_store is not None
+            and ctx.wgrad_store.delay_wgrad_compute()
+            and not ctx.fp8
+        ):
             grad_biases = [None] * ctx.num_gemms

         if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
             FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
```
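The backward pass now wraps the wgrad batch-GEMM in a `functools.partial` and either runs it immediately or parks it in `ctx.wgrad_store` for later. Judging from the call sites (`put(tensor_list, fn)` here and `pop()` in `backward_dw` below), the store follows roughly this contract; the class below is an illustrative stand-in, not the real `WeightGradStore` from `._common`:

```python
class TinyWgradStore:
    """Minimal sketch of the put/pop contract used by delay_wgrad_compute."""

    def __init__(self, delay: bool):
        self._delay = delay
        self._queue = []

    def delay_wgrad_compute(self) -> bool:
        return self._delay

    def put(self, tensor_list, wgrad_fn):
        # Defer the weight-gradient GEMM: keep its inputs and the callable.
        self._queue.append((tensor_list, wgrad_fn))

    def pop(self):
        # Run the deferred GEMM now; hand back its result and its inputs.
        tensor_list, wgrad_fn = self._queue.pop(0)
        return wgrad_fn(*tensor_list), tensor_list
```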
```diff
@@ -304,6 +323,7 @@ class _BatchLinear(torch.autograd.Function):
             None,  # is_first_microbatch
             None,  # fp8
             None,  # fp8_calibration
+            None,  # wgrad_store
             None,  # fp8_meta
             None,  # fuse_wgrad_accumulation
             None,  # cpu_offloading
@@ -381,6 +401,8 @@ class BatchedLinear(TransformerEngineBaseModule):
                           it controls the type used to allocate the initial parameters. Useful when
                           the model is trained with lower precision and the original FP32 parameters
                           would not fit in GPU memory.
+    delay_wgrad_compute : bool, default = `False`
+                         Whether to delay weight gradient computation
     """

     def __init__(
@@ -403,6 +425,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         ub_overlap_rs: bool = False,
         ub_overlap_ag: bool = False,
         ub_name: Optional[str] = None,
+        delay_wgrad_compute: bool = False,
     ) -> None:
         super().__init__()
@@ -424,6 +447,8 @@ class BatchedLinear(TransformerEngineBaseModule):
         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
+        self.wgrad_store = WeightGradStore(delay_wgrad_compute)

         global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
         _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT = 0, self.num_gemms, 2 * self.num_gemms
@@ -588,6 +613,7 @@ class BatchedLinear(TransformerEngineBaseModule):
             is_first_microbatch,
             self.fp8,
             self.fp8_calibration,
+            self.wgrad_store,
             self.fp8_meta,
             self.fuse_wgrad_accumulation,
             CPUOffloadEnabled,
@@ -617,3 +643,27 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.return_bias:
             return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
         return out
+
+    def backward_dw(self):
+        """
+        Execute the delayed weight gradient computation.
+        This method is called after the main backward pass to compute weight gradients.
+        """
+        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
+            return
+        with torch.cuda.nvtx.range("_GroupedLinear_wgrad"):
+            (_, grad_biases_, _), tensor_list = self.wgrad_store.pop()
+            wgrad_list = tensor_list[2]
+            if not self.fuse_wgrad_accumulation:
+                for i in range(self.num_gemms):
+                    weight_param = getattr(self, f"weight{i}")
+                    if weight_param.grad is None:
+                        weight_param.grad = wgrad_list[i].to(weight_param.dtype)
+            if self.use_bias:
+                for i in range(self.num_gemms):
+                    bias_param = getattr(self, f"bias{i}")
+                    if bias_param.grad is None:
+                        bias_param.grad = grad_biases_[i].to(bias_param.dtype)
+            del grad_biases_
+            del wgrad_list
+            del tensor_list
```
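End to end, the new option is used the way the test above exercises it: run forward and backward as usual, then trigger the deferred weight-gradient GEMMs explicitly. A hedged sketch; the import path follows this MR, and the positional constructor arguments (num_gemms, in/out features) are assumptions from context:

```python
import torch
from transformer_engine.pytorch.module.batched_linear import BatchedLinear

layer = BatchedLinear(2, 256, 256, delay_wgrad_compute=True)  # num_gemms, in/out features assumed
inp = torch.randn(8, 256, device="cuda", requires_grad=True)

out = layer(inp)
out.sum().backward()   # dgrad runs here; the wgrad GEMM is queued in layer.wgrad_store
layer.backward_dw()    # execute the delayed weight-gradient computation
```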