OpenDAS / TransformerEngine / Commits

Commit f8c2af4c, authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211

Showing 20 changed files with 820 additions and 426 deletions (+820, -426):
transformer_engine/pytorch/csrc/extensions/cast.cpp (+22, -9)
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (+113, -120)
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp (+51, -0)
transformer_engine/pytorch/csrc/extensions/gemm.cpp (+73, -61)
transformer_engine/pytorch/csrc/extensions/misc.cpp (+4, -0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp (+98, -0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp (+24, -0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp (+105, -0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp (+22, -0)
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp (+25, -0)
transformer_engine/pytorch/csrc/extensions/normalization.cpp (+78, -69)
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp (+4, -2)
transformer_engine/pytorch/csrc/extensions/padding.cpp (+12, -8)
transformer_engine/pytorch/csrc/extensions/permutation.cpp (+48, -42)
transformer_engine/pytorch/csrc/extensions/pybind.cpp (+107, -80)
transformer_engine/pytorch/csrc/extensions/recipe.cpp (+6, -8)
transformer_engine/pytorch/csrc/extensions/softmax.cpp (+4, -14)
transformer_engine/pytorch/csrc/extensions/transpose.cpp (+10, -10)
transformer_engine/pytorch/csrc/pybind.h (+12, -0)
transformer_engine/pytorch/csrc/quantizer.cpp (+2, -3)
transformer_engine/pytorch/csrc/extensions/cast.cpp

@@ -52,7 +52,9 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob

   if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
     // my_quantizer here has to be a Float8CurrentScalingQuantizer
     auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream());
+    });
     // check if we need to do amax reduction (depending on model parallel configs)
     if (my_quantizer_cs->with_amax_reduction) {
       c10::intrusive_ptr<dist_group_type> process_group_ptr = my_quantizer_cs->amax_reduction_group;

@@ -69,7 +71,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob

     // so in nvte_quantize_v2 with current scaling, the quant config is not used again
     quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-    nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream());
+    });
     // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
     te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape);
   } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) {

@@ -77,8 +82,10 @@ py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::ob

     quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
     quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
   }
+  NVTE_SCOPED_GIL_RELEASE({
     nvte_quantize_v2(te_input.data(), te_output.data(), quant_config,
                      at::cuda::getCurrentCUDAStream());
+  });
   return out;
 }

@@ -96,7 +103,9 @@ py::object dequantize(const py::handle& input, transformer_engine::DType otype)

   auto [out_tensor, out] = q.create_tensor(shape, otype);
+  NVTE_SCOPED_GIL_RELEASE({
     nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream());
+  });
   return out;
 }

@@ -120,15 +129,19 @@ std::vector<py::object> dbias_dact(const at::Tensor& grad_output, const at::Tens

   // Query workspace size and allocate workspace
   transformer_engine::TensorWrapper workspace;
+  NVTE_SCOPED_GIL_RELEASE({
     func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
          workspace.data(), at::cuda::getCurrentCUDAStream());
+  });
   auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype());
   workspace =
       makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype());

   // Launch kernel
+  NVTE_SCOPED_GIL_RELEASE({
     func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(),
          workspace.data(), at::cuda::getCurrentCUDAStream());
+  });
   return {py::cast(grad_bias), dact};
 }
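The recurring change in this file (and across the whole commit) is wrapping kernel launches in NVTE_SCOPED_GIL_RELEASE so the Python GIL is dropped while CUDA work is enqueued. A minimal sketch of what such a macro can expand to, assuming pybind11's scoped GIL guard; the actual NVTE_SCOPED_GIL_RELEASE definition in Transformer Engine may differ:

```cpp
// Sketch of a GIL-release-around-launch macro, assuming pybind11's
// gil_scoped_release. Requires an initialized Python interpreter at runtime.
#include <pybind11/pybind11.h>

#define SCOPED_GIL_RELEASE(code)                  \
  do {                                            \
    pybind11::gil_scoped_release release_guard_;  \
    code;                                         \
  } while (false)

void enqueue_work() {
  SCOPED_GIL_RELEASE({
    // Enqueue CUDA work here. Because the GIL is dropped, a slow launch path
    // or blocking copy no longer stalls other Python threads.
    // No Python C-API calls are allowed inside this region.
  });
}
```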
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp

@@ -141,81 +141,79 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType

                   num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm,
                   set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {}

+void CommOverlap::set_buffer_params(py::handle quantizer) {
+  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
+  my_quantizer->set_quantization_params(&_ubuf);
+  _ubuf_scale_inv_initialized = true;
+}
+
 /*
 ** Helper function to copy input to _ubuf
 */
-void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-  char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.dptr());
+void CommOverlap::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  const size_t ubuf_size = _ubuf.numel();
+  void *dst_ptr = _ubuf.dptr();
+
   if (local_chunk) {
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size();
+    NVTE_CHECK(input_size * _tp_size == ubuf_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", tensor_parallel_size=", _tp_size,
+               ", ubuf_size=", ubuf_size, ")");
+    dst_ptr = (reinterpret_cast<char *>(dst_ptr) + (ubuf_size / _tp_size) * _tp_id * element_size);
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
+    NVTE_CHECK(input_size == ubuf_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", ubuf_size, ")");
   }

-  // Copy either row or columnwise data into the communication buffer's columnwise data
-  // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with
-  // the Userbuffers communicator.
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+  // Copy data
+  auto stream_main = at::cuda::getCurrentCUDAStream();
   NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
   NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
-  NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(),
-                                  input_tensor.numel() * input_tensor.element_size(),
-                                  cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
                                   cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
 }

-py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk,
-                                   std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.dptr());
-  if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
+at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  const size_t ubuf_size = _ubuf.numel();
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size * _tp_size == ubuf_size,
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", tensor_parallel_size=", _tp_size, ", ubuf_size=", ubuf_size, ")");
+    } else {
+      NVTE_CHECK(requested_size == ubuf_size,
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", ubuf_size, ")");
+    }
   } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
   }

-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+  // Data pointer
+  void *ubuf_ptr = _ubuf.dptr();
+  if (local_chunk) {
+    ubuf_ptr = (reinterpret_cast<char *>(ubuf_ptr) +
+                (ubuf_size / _tp_size) * _tp_id * _ubuf.element_size());
+  }
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }

/***************************************************************************************************

@@ -236,74 +234,69 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal

                      comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin,
                      use_ce, atomic_gemm, aggregate) {}

+void CommOverlapP2P::set_buffer_params(py::handle quantizer) {
+  std::unique_ptr<te::pytorch::Quantizer> my_quantizer = te::pytorch::convert_quantizer(quantizer);
+  my_quantizer->set_quantization_params(&_ubuf);
+  for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]);
+}
+
 /*
 ** Copy input to _ubufs[0]
 */
-void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) {
-  auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer);
-  auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr();
-  NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!");
-  at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
+void CommOverlapP2P::copy_into_buffer(const at::Tensor &input, bool local_chunk) {
+  const auto &input_ = input.contiguous();
+
+  // Check element size
+  const size_t element_size = input.element_size();
+  NVTE_CHECK(_ubuf.element_size() == element_size,
+             "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ",
+             "(input dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(),
+             " bytes)");
+
+  // Input data
+  const size_t input_size = input_.numel();
+  const void *src_ptr = input_.data_ptr();
+
+  // Userbuffers data
+  void *dst_ptr;
   if (local_chunk) {
-    // Copy input to the target ubuf chunk by rank offset
-    if (input_tensor.numel() * _tp_size > _ubuf.numel())
-      NVTE_ERROR("input is larger than the local communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubufs[_tp_id].numel() == input_size,
+               "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ",
+               "(input_size=", input_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    dst_ptr = _ubufs[_tp_id].dptr();
   } else {
-    if (input_tensor.numel() > _ubuf.numel())
-      NVTE_ERROR("input is larger than the global communication buffer!");
-    if (input_tensor.element_size() != _ubuf.element_size())
-      NVTE_ERROR("input data type does not match communication buffer!");
-    NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr,
-                                    input_tensor.numel() * input_tensor.element_size(),
-                                    cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main));
+    NVTE_CHECK(_ubuf.numel() == input_size,
+               "Tried to copy an invalid tensor into a Userbuffers buffer ",
+               "(input_size=", input_size, ", ubuf_size=", _ubuf.numel(), ")");
+    dst_ptr = _ubuf.dptr();
   }
+
+  // Copy data
+  NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, input_size * element_size,
+                                  cudaMemcpyDeviceToDevice,
+                                  (cudaStream_t)at::cuda::getCurrentCUDAStream()));
 }

-py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk,
-                                      std::optional<const std::vector<int64_t>> shape) {
-  using namespace te::pytorch;
-  char *ubuf_wt_ptr = reinterpret_cast<char *>(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr());
-
-  std::vector<int64_t> torch_shape;
-  if (shape.has_value()) {
-    torch_shape = shape.value();
-    size_t requested = product(torch_shape);
-    auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel();
-    NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested,
-               ") does not match allocated buffer size (", expected, ")!");
-  } else {
-    int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0);
-    int64_t output_c_dim1 = _ubuf.size(1);
-    torch_shape = {output_c_dim0, output_c_dim1};
-  }
-
-  auto ubuf_tensor = torch::from_blob(reinterpret_cast<void *>(ubuf_wt_ptr), torch_shape,
-                                      at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA));
-  std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
-  std::vector<size_t> te_shape;
-  for (auto s : torch_shape) te_shape.emplace_back(static_cast<size_t>(s));
-
-  // Always output a rowwise-only QuantizedTensor
-  // TODO (Alp): This needs to produce an un-interleaved transpose when required.
-  auto is_internal = my_quantizer->internal;
-  auto uses_columnwise = my_quantizer->columnwise_usage;
-  my_quantizer->internal = false;
-  my_quantizer->columnwise_usage = false;
-  auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor);
-  my_quantizer->internal = is_internal;
-  my_quantizer->columnwise_usage = uses_columnwise;
-  return py_tensor;
+at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional<std::vector<int64_t>> shape) {
+  // Check buffer shape
+  if (shape) {
+    const size_t requested_size = transformer_engine::pytorch::product(*shape);
+    if (local_chunk) {
+      NVTE_CHECK(requested_size == _ubufs[_tp_id].numel(),
+                 "Invalid shape for local chunk of a Userbuffers buffer (requested shape=", *shape,
+                 ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")");
+    } else {
+      NVTE_CHECK(requested_size == _ubuf.numel(),
+                 "Invalid shape for a Userbuffers buffer (requested shape=", *shape,
+                 ", ubuf_size=", _ubuf.numel(), ")");
+    }
+  } else {
+    int64_t dim0 = _ubuf.size(0);
+    int64_t dim1 = _ubuf.size(1);
+    if (local_chunk) {
+      dim0 /= _tp_size;
+    }
+    shape = {dim0, dim1};
+  }
+
+  // Data pointer
+  void *ubuf_ptr = local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr();
+
+  // Construct PyTorch tensor
+  const auto dtype = transformer_engine::pytorch::GetATenDType(_ubuf.dtype());
+  return torch::from_blob(ubuf_ptr, *shape, at::dtype(dtype).device(torch::kCUDA));
 }
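Both the old and new code address a rank's local chunk the same way: the shared Userbuffers region holds _tp_size equal chunks, and rank _tp_id owns the chunk starting at (ubuf_size / _tp_size) * _tp_id elements. A self-contained sketch of that arithmetic, with a hypothetical helper whose names mirror the members used in the diff:

```cpp
// Byte address of this tensor-parallel rank's chunk inside the shared buffer.
#include <cassert>
#include <cstddef>

char *local_chunk_ptr(char *ubuf_base, std::size_t ubuf_numel, std::size_t element_size,
                      std::size_t tp_size, std::size_t tp_id) {
  assert(ubuf_numel % tp_size == 0);  // the buffer divides evenly across ranks
  const std::size_t chunk_elems = ubuf_numel / tp_size;
  return ubuf_base + chunk_elems * tp_id * element_size;  // offset in bytes
}
```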
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len) {
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(amax.dim() == 2, "amax must be a 2D tensor");
  TORCH_CHECK(amax.scalar_type() == at::ScalarType::Float, "amax must be a float tensor");
  TORCH_CHECK(tensor.scalar_type() == at::ScalarType::Float ||
                  tensor.scalar_type() == at::ScalarType::BFloat16,
              "tensor must be a float or bfloat16 tensor");

  const TensorWrapper tensor_cu = makeTransformerEngineTensor(tensor);
  TensorWrapper amax_cu = makeTransformerEngineTensor(amax);

  nvte_fp8_block_scaling_compute_partial_amax(tensor_cu.data(), amax_cu.data(), h, w,
                                              amax.stride(0), amax.stride(1), start_offset,
                                              block_len, at::cuda::getCurrentCUDAStream());
}

void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
                                    const transformer_engine::DType out_dtype) {
  TORCH_CHECK(block_len == 128, "Currently only block_len = 128 is supported");
  TORCH_CHECK(scale.dim() == 2, "scale must be a 2D tensor");
  TORCH_CHECK(scale.scalar_type() == at::ScalarType::Float, "scale must be a float tensor");
  TORCH_CHECK(inp.scalar_type() == at::ScalarType::Float ||
                  inp.scalar_type() == at::ScalarType::BFloat16,
              "input must be a float or bfloat16 tensor");
  TORCH_CHECK(out.scalar_type() == at::ScalarType::Byte, "output must be a uint8 tensor");
  TORCH_CHECK(out_dtype == transformer_engine::DType::kFloat8E4M3 ||
                  out_dtype == transformer_engine::DType::kFloat8E5M2,
              "out_dtype must be kFloat8E4M3 or kFloat8E5M2");

  const TensorWrapper inp_cu = makeTransformerEngineTensor(inp);
  TensorWrapper out_cu = makeTransformerEngineTensor(out);
  const TensorWrapper scale_cu = makeTransformerEngineTensor(scale);

  nvte_fp8_block_scaling_partial_cast(inp_cu.data(), out_cu.data(), scale_cu.data(), h, w,
                                      scale.stride(0), scale.stride(1), start_offset, block_len,
                                      static_cast<NVTEDType>(out_dtype),
                                      at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
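For orientation, block scaling pairs each 128-element block with one scale derived from the block's amax, so the block's largest magnitude maps near the FP8 format's maximum representable value (448 for E4M3). A rough host-side illustration of that arithmetic, which is an assumption based on the parameter names rather than the kernel source:

```cpp
// One block's scale from its amax; the cast multiplies by scale, and
// dequantization multiplies by scale_inv.
#include <cmath>
#include <cstdio>

int main() {
  const float amax = 12.5f;      // largest |x| within one 128-element block
  const float fp8_max = 448.0f;  // max normal magnitude of E4M3
  const float scale = fp8_max / amax;
  const float scale_inv = 1.0f / scale;
  std::printf("scale=%f scale_inv=%f\n", scale, scale_inv);
}
```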
transformer_engine/pytorch/csrc/extensions/gemm.cpp

@@ -4,7 +4,6 @@

  * See LICENSE for license information.
  ************************************************************************/

-#include <Python.h>
 #include <pybind11/pybind11.h>

 #include <optional>

@@ -21,12 +20,12 @@

 namespace {

-void *get_data_ptr(MaybeTensor tensor) {
+void *get_data_ptr(transformer_engine::pytorch::MaybeTensor tensor) {
   if (tensor.has_value()) return tensor->data_ptr();
   return nullptr;
 }

-size_t get_size(MaybeTensor tensor, int dim) {
+size_t get_size(transformer_engine::pytorch::MaybeTensor tensor, int dim) {
   if (tensor.has_value()) return static_cast<size_t>(tensor->size(dim));
   return 0;
 }

@@ -167,8 +166,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

       makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type);

   // Workspace
-  auto te_workspace =
-      makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
+  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
+                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

   // Set an external SM Margin to all the GEMMs.
   // This comes in handy when DP is overlapped with GEMMs

@@ -197,38 +196,52 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

     // Direct GEMM call to the correct overlap
     if (bulk_overlap) {
+      NVTE_SCOPED_GIL_RELEASE({
        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
                                   te_pre_gelu_out, te_workspace, grad, accumulate,
                                   use_split_accumulator, comm_type.value(), extra_output_tensor,
                                   main_stream);
+      });
     } else if (comm_type.value() == CommOverlapType::AG) {
       if (comm_overlap->is_atomic_gemm()) {
+        NVTE_SCOPED_GIL_RELEASE({
          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                               accumulate, use_split_accumulator,
                                               extra_output_tensor, main_stream);
+        });
       } else {
-        comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
-                                       te_pre_gelu_out, te_workspace, grad, accumulate,
-                                       use_split_accumulator, extra_output_tensor, main_stream);
+        NVTE_SCOPED_GIL_RELEASE({
+          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+                                         te_pre_gelu_out, te_workspace, grad, accumulate,
+                                         use_split_accumulator, extra_output_tensor, main_stream);
+        });
       }
     } else {
       if (comm_overlap->is_atomic_gemm()) {
+        NVTE_SCOPED_GIL_RELEASE({
          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
                                               bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                               accumulate, use_split_accumulator,
                                               extra_output_tensor, main_stream);
+        });
       } else {
-        comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
-                                       te_pre_gelu_out, te_workspace, grad, accumulate,
-                                       use_split_accumulator, extra_output_tensor, main_stream);
+        NVTE_SCOPED_GIL_RELEASE({
+          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+                                         te_pre_gelu_out, te_workspace, grad, accumulate,
+                                         use_split_accumulator, extra_output_tensor, main_stream);
+        });
       }
     }
   } else {
     // Launch GEMM
+    NVTE_SCOPED_GIL_RELEASE({
      nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(),
                       te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                       accumulate, use_split_accumulator, num_math_sms, main_stream);
+    });
   }
 } else {
   if (D_tensor.numel() != 0 && !accumulate) {

@@ -258,20 +271,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans

   return out;
 }

-}  // namespace transformer_engine::pytorch
-
-void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
+void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
                     std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
-                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
+                    at::Tensor B_scale_inverse, DType B_type,
                     std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
-                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
-                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
+                    at::Tensor D_scale, DType D_type, at::Tensor D_amax,
+                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out,
                     bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                     bool gemm_producer, at::Tensor counter) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   // TODO: Handle scaling modes
   NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING;
   NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING;

@@ -286,12 +293,13 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine

       nvte_scaling_modeB);

   // TODO: D_scale_inv cannot be nullptr when D_type is FP8.
-  auto te_D = makeTransformerEngineTensor(
-      D.data_ptr(), {static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
-      D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
-  auto te_bias =
-      makeTransformerEngineTensor(bias.data_ptr(), {static_cast<size_t>(bias.size(0))}, bias_type);
-  auto te_counter = makeTransformerEngineTensor(
-      counter.data_ptr(), {static_cast<size_t>(counter.size(0))}, DType::kInt32);
+  auto te_D = makeTransformerEngineTensor(
+      D.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(D.size(0)), static_cast<size_t>(D.size(1))}, D_type,
+      D_amax.data_ptr(), D_scale.data_ptr(), nullptr);
+  auto te_bias = makeTransformerEngineTensor(
+      bias.data_ptr(), std::vector<size_t>{static_cast<size_t>(bias.size(0))}, bias_type);
+  auto te_counter = makeTransformerEngineTensor(
+      counter.data_ptr(), std::vector<size_t>{static_cast<size_t>(counter.size(0))}, DType::kInt32);

   const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr
                               ? std::vector<size_t>{static_cast<size_t>(pre_gelu_out.size(0))}

@@ -299,24 +307,23 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine

                                     static_cast<size_t>(pre_gelu_out.size(1))};
   auto te_pre_gelu_out = makeTransformerEngineTensor(
       pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type()));
-  auto te_workspace =
-      makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte);
+  auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(),
+                                                  std::vector<size_t>{workspaceSize}, DType::kByte);

+  NVTE_SCOPED_GIL_RELEASE({
     nvte_cublas_atomic_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(),
                             te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(),
                             accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
                             gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream());
+  });
 }

 std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
     std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
-    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
-    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
-    transformer_engine::DType bias_type, bool single_output,
+    std::optional<std::vector<at::Tensor>> D, DType D_type,
+    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias, DType bias_type,
+    bool single_output,
     std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
     size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
       te_pre_gelu_out_vector, te_workspace_vector;
   std::vector<TensorWrapper> wrappers;

@@ -419,16 +426,19 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(

     wrappers.emplace_back(std::move(te_pre_gelu_out));
   }
   for (size_t i = 0; i < workspace.size(); i++) {
-    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
+    auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(),
+                                           std::vector<size_t>{workspaceSize}, DType::kByte);
     te_workspace_vector.emplace_back(wsp.data());
     wrappers.emplace_back(std::move(wsp));
   }
   // For now, we only have multi-stream cublas backend.
+  NVTE_SCOPED_GIL_RELEASE({
     nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
                                   te_bias_vector.data(), te_pre_gelu_out_vector.data(),
                                   te_A_vector.size(), transa, transb, grad,
                                   te_workspace_vector.data(), accumulate, use_split_accumulator,
                                   math_sm_count, at::cuda::getCurrentCUDAStream());
+  });
   return bias;
 }

@@ -534,3 +544,5 @@ std::vector<at::Tensor> te_batchgemm_ts(

 }
 #endif
+
+}  // namespace transformer_engine::pytorch
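Several hunks above replace a bare braced list `{workspaceSize}` with an explicit `std::vector<size_t>{workspaceSize}`. A plausible reason (an assumption, since the commit message does not say) is overload disambiguation: once makeTransformerEngineTensor gains overloads taking different shape types, a bare braced list can match more than one. A stand-in illustration of that failure mode, with hypothetical types:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

struct FakeShape { std::size_t ndim; };  // stand-in for a second shape type

void make_tensor(void *, const std::vector<std::size_t> &) { std::puts("vector overload"); }
void make_tensor(void *, const FakeShape &) { std::puts("shape overload"); }

int main() {
  std::size_t workspaceSize = 1024;
  void *ptr = nullptr;
  // make_tensor(ptr, {workspaceSize});  // error: braced list matches both overloads
  make_tensor(ptr, std::vector<std::size_t>{workspaceSize});  // explicit, unambiguous
}
```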
transformer_engine/pytorch/csrc/extensions/misc.cpp

@@ -6,6 +6,8 @@

 #include "extensions.h"

+namespace transformer_engine::pytorch {
+
 #ifdef USE_ROCM
 size_t get_cublasLt_version() {
   int version = 10000000;
   return version;
 }

@@ -15,3 +17,5 @@ size_t get_cublasLt_version() { return cublasLtGetVersion(); }

 size_t get_cudnn_version() { return cudnnGetVersion(); }
 #endif
+
+}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                              num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction,
                              weight_decay, device_id, at::cuda::getCurrentCUDAStream());
}

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_adam_param_remainder_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1,
      beta2, epsilon, step, mode, bias_correction, weight_decay, device_id,
      at::cuda::getCurrentCUDAStream());
}

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                                  num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode,
                                  bias_correction, weight_decay, static_cast<NVTEDType>(fp8_dtype),
                                  device_id, at::cuda::getCurrentCUDAStream());
}

void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr, const float beta1, const float beta2,
                                       const float epsilon, at::Tensor step, const int mode,
                                       const int bias_correction, const float weight_decay,
                                       at::Tensor inv_scale) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  auto lr_cu = makeTransformerEngineTensor(lr);
  auto step_cu = makeTransformerEngineTensor(step);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_adam_capturable_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}

void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, at::Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              at::Tensor inv_scale) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  auto lr_cu = makeTransformerEngineTensor(lr);
  auto step_cu = makeTransformerEngineTensor(step);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_adam_capturable_master_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay,
      inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
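For reference, a scalar version of the per-element update these fused multi-tensor Adam kernels apply. This is a sketch of the standard Adam recipe matching the parameter names above (bias correction, L2-style weight decay); the exact kernel semantics may differ in details such as the mode flag:

```cpp
#include <cmath>

// p: parameter, g: gradient, m/v: first and second moments (updated in place).
void adam_step(float &p, float g, float &m, float &v, int step, float lr, float beta1,
               float beta2, float epsilon, float weight_decay, bool bias_correction) {
  m = beta1 * m + (1.0f - beta1) * g;      // exponential moving average of g
  v = beta2 * v + (1.0f - beta2) * g * g;  // exponential moving average of g^2
  float m_hat = m, v_hat = v;
  if (bias_correction) {
    m_hat /= 1.0f - std::pow(beta1, static_cast<float>(step));
    v_hat /= 1.0f - std::pow(beta2, static_cast<float>(step));
  }
  p -= lr * (m_hat / (std::sqrt(v_hat) + epsilon) + weight_decay * p);
}
```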
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

void multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    float max_fp8, bool force_pow_2_scales, float epsilon) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_compute_scale_and_scale_inv_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8,
      force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
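A scalar sketch of what this binding plausibly computes per tensor, inferred from the parameter names (max_fp8, epsilon, force_pow_2_scales) rather than the kernel source: the scale maps amax to the FP8 maximum, clamped away from division by tiny amax, and optionally snapped to a power of two so scaling and unscaling are exactly invertible:

```cpp
#include <cmath>

float compute_scale(float amax, float max_fp8, float epsilon, bool force_pow_2) {
  float scale = max_fp8 / std::fmax(amax, epsilon);
  if (force_pow_2) {
    scale = std::exp2(std::floor(std::log2(scale)));  // largest power of two <= scale
  }
  return scale;  // scale_inv is simply 1.0f / scale
}
```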
transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python) {
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  auto ret = at::empty({1}, output.options());

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  if (per_tensor) {
    for (int t = 0; t < ntensors; t++) {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
      if (max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  } else {
    output_per_tensor = at::empty({0}, float_options);
    ret_per_tensor = at::empty({0}, float_options);
  }

  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  auto output_cu = makeTransformerEngineTensor(output);
  auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
  auto ret_cu = makeTransformerEngineTensor(ret);
  auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                                num_lists, num_tensors, output_cu.data(),
                                output_per_tensor_cu.data(), ret_cu.data(),
                                ret_per_tensor_cu.data(), per_tensor, max_chunks_per_tensor,
                                device_id, at::cuda::getCurrentCUDAStream());

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}

std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale, at::optional<bool> per_tensor_python) {
  bool per_tensor = per_tensor_python.has_value() ? per_tensor_python.value() : false;

  auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
  auto output = at::zeros({320}, float_options);

  at::Tensor output_per_tensor;
  at::Tensor ret_per_tensor;

  int ntensors = tensor_lists[0].size();
  int max_chunks_per_tensor = -1;

  // Create output tensors for multi scale L2 norm kernel.
  if (per_tensor) {
    for (int t = 0; t < ntensors; t++) {
      int max_chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
      if (max_chunks_this_tensor > max_chunks_per_tensor)
        max_chunks_per_tensor = max_chunks_this_tensor;
    }
    output_per_tensor = at::zeros({ntensors * max_chunks_per_tensor}, float_options);
    ret_per_tensor = at::empty({ntensors}, float_options);
  } else {
    output_per_tensor = at::empty({0}, float_options);
    ret_per_tensor = at::empty({0}, float_options);
  }
  auto ret = at::empty({1}, output.options());

  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  auto output_cu = makeTransformerEngineTensor(output);
  auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor);
  auto ret_cu = makeTransformerEngineTensor(ret);
  auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor);
  auto inv_scale_cu = makeTransformerEngineTensor(inv_scale);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_unscale_l2norm_cuda(
      chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors,
      output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(),
      inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id,
      at::cuda::getCurrentCUDAStream());

  return std::tuple<at::Tensor, at::Tensor>(ret, ret_per_tensor);
}

}  // namespace transformer_engine::pytorch
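The max_chunks_per_tensor bookkeeping above sizes a per-chunk partial-result buffer: each chunk of each tensor contributes one partial sum of squares, which is then reduced into the final norm. A host-side sketch of the same chunked reduction for a single tensor:

```cpp
#include <cmath>
#include <vector>

float chunked_l2norm(const std::vector<float> &x, int chunk_size) {
  // Ceiling division, as in the binding's max_chunks_this_tensor computation.
  const int nchunks = (static_cast<int>(x.size()) + chunk_size - 1) / chunk_size;
  std::vector<float> partial(nchunks, 0.0f);
  for (std::size_t i = 0; i < x.size(); ++i) partial[i / chunk_size] += x[i] * x[i];
  float total = 0.0f;
  for (float p : partial) total += p;  // second-stage reduction
  return std::sqrt(total);
}
```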
transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(),
                               num_lists, num_tensors, scale, device_id,
                               at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp (new file, mode 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "extensions.h"

namespace transformer_engine::pytorch {

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov,
                           bool first_run, bool wd_after_momentum, float scale) {
  auto noop_flag_cu = makeTransformerEngineTensor(noop_flag);
  auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
      makeTransformerEngineTensorList(tensor_lists);
  int device_id = tensor_lists[0][0].device().index();

  nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists,
                             num_tensors, wd, momentum, dampening, lr, nesterov, first_run,
                             wd_after_momentum, scale, device_id,
                             at::cuda::getCurrentCUDAStream());
}

}  // namespace transformer_engine::pytorch
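A scalar reference for the fused SGD update, following the parameter list above. The semantics here are the usual Apex-style ones and are assumed rather than taken from the kernel source; treat this as orientation, not a specification:

```cpp
// p: parameter, g: gradient, buf: momentum buffer (updated in place).
void sgd_step(float &p, float g, float &buf, float wd, float momentum, float dampening,
              float lr, bool nesterov, bool first_run, bool wd_after_momentum, float scale) {
  g *= scale;                           // undo gradient scaling (e.g. loss scaling)
  if (!wd_after_momentum) g += wd * p;  // classic L2 regularization on the gradient
  buf = first_run ? g : momentum * buf + (1.0f - dampening) * g;
  float update = nesterov ? g + momentum * buf : buf;
  if (wd_after_momentum) update += wd * p;  // weight decay applied after momentum
  p -= lr * update;
}
```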
transformer_engine/pytorch/csrc/extensions/normalization.cpp
View file @
f8c2af4c
...
...
@@ -9,28 +9,11 @@
#include "pybind.h"
namespace
transformer_engine
::
pytorch
{
std
::
pair
<
TensorWrapper
,
py
::
object
>
createOutputTensor
(
const
NVTEShape
&
shape
,
DType
dtype
,
py
::
handle
quantizer
)
{
std
::
vector
<
size_t
>
shape_vec
;
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
;
i
++
)
{
size_t
t
=
shape
.
data
[
i
];
shape_vec
.
push_back
(
t
);
}
std
::
unique_ptr
<
Quantizer
>
my_quantizer
=
convert_quantizer
(
quantizer
);
return
my_quantizer
->
create_tensor
(
shape_vec
,
dtype
);
}
std
::
pair
<
TensorWrapper
,
py
::
object
>
createOutputTensor
(
std
::
vector
<
size_t
>
&
shape
,
DType
dtype
,
py
::
handle
quantizer
)
{
std
::
unique_ptr
<
Quantizer
>
my_quantizer
=
convert_quantizer
(
quantizer
);
return
my_quantizer
->
create_tensor
(
shape
,
dtype
);
}
}
// namespace transformer_engine::pytorch
std
::
vector
<
py
::
object
>
layernorm_bwd
(
const
at
::
Tensor
&
dz
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
mu
,
const
at
::
Tensor
&
rsigma
,
const
at
::
Tensor
&
gamma
,
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
;
const
auto
&
dz_
=
dz
.
contiguous
();
const
auto
&
x_
=
x
.
contiguous
();
const
auto
&
mu_
=
mu
.
contiguous
();
...
...
@@ -40,7 +23,7 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto
dx
=
at
::
empty_like
(
x_
);
auto
dgamma
=
at
::
empty_like
(
gamma_
);
auto
dbeta
=
at
::
empty_like
(
gamma_
);
transformer_engine
::
TensorWrapper
workspace
;
TensorWrapper
workspace
;
auto
dz_cu
=
makeTransformerEngineTensor
(
dz_
);
auto
x_cu
=
makeTransformerEngineTensor
(
x_
);
...
...
@@ -52,10 +35,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto
dbeta_cu
=
makeTransformerEngineTensor
(
dbeta
);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE
({
nvte_layernorm_bwd
(
dz_cu
.
data
(),
x_cu
.
data
(),
mu_cu
.
data
(),
rsigma_cu
.
data
(),
gamma_cu
.
data
(),
dx_cu
.
data
(),
dgamma_cu
.
data
(),
dbeta_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Alloc space for Tensors.
auto
workspace_data
=
allocateSpace
(
workspace
.
shape
(),
workspace
.
dtype
());
...
...
@@ -63,10 +48,12 @@ std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor
(
workspace_data
.
data_ptr
(),
workspace
.
shape
(),
workspace
.
dtype
());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE
({
nvte_layernorm_bwd
(
dz_cu
.
data
(),
x_cu
.
data
(),
mu_cu
.
data
(),
rsigma_cu
.
data
(),
gamma_cu
.
data
(),
dx_cu
.
data
(),
dgamma_cu
.
data
(),
dbeta_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
return
{
py
::
cast
(
dx
),
py
::
cast
(
dgamma
),
py
::
cast
(
dbeta
)};
}
...
...
@@ -76,8 +63,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
DType
out_dtype
,
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
using
namespace
transformer_engine
::
pytorch
;
using
namespace
transformer_engine
;
// Input and param tensors
auto
none
=
py
::
none
();
...
...
@@ -131,11 +116,13 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
TensorWrapper
&
kernel_out_cu
=
force_unfused_kernel
?
unquantized_out_cu
:
out_cu
;
// Query workspace size
transformer_engine
::
TensorWrapper
workspace
;
TensorWrapper
workspace
;
NVTE_SCOPED_GIL_RELEASE
({
nvte_layernorm_fwd
(
input_cu
.
data
(),
weight_cu
.
data
(),
bias_cu
.
data
(),
eps
,
kernel_out_cu
.
data
(),
mu_cu
.
data
(),
rsigma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Allocate workspace
auto
workspace_data
=
allocateSpace
(
workspace
.
shape
(),
workspace
.
dtype
());
...
...
@@ -143,10 +130,12 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
makeTransformerEngineTensor
(
workspace_data
.
data_ptr
(),
workspace
.
shape
(),
workspace
.
dtype
());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
({
nvte_layernorm_fwd
(
input_cu
.
data
(),
weight_cu
.
data
(),
bias_cu
.
data
(),
eps
,
kernel_out_cu
.
data
(),
mu_cu
.
data
(),
rsigma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Quantize output if using unfused kernel
if
(
force_unfused_kernel
)
{
...
...
@@ -154,7 +143,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
if
(
IsFloat8CurrentScalingQuantizers
(
quantizer
.
ptr
()))
{
// my_quantizer here has to be a Float8CurrentScalingQuantizer
auto
my_quantizer_cs
=
static_cast
<
Float8CurrentScalingQuantizer
*>
(
my_quantizer
.
get
());
nvte_compute_amax
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_amax
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
at
::
cuda
::
getCurrentCUDAStream
());
});
// check if we need to do amax reudction (depending on model parallel configs)
if
(
my_quantizer_cs
->
with_amax_reduction
)
{
c10
::
intrusive_ptr
<
dist_group_type
>
process_group_ptr
=
...
...
@@ -169,7 +161,9 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
}
quant_config
.
set_force_pow_2_scales
(
my_quantizer_cs
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_cs
->
amax_epsilon
);
NVTE_SCOPED_GIL_RELEASE
({
nvte_compute_scale_from_amax
(
out_cu
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
out_cu
.
set_amax
(
nullptr
,
DType
::
kFloat32
,
out_cu
.
defaultShape
);
}
else
if
(
IsFloat8BlockwiseQuantizers
(
quantizer
.
ptr
()))
{
...
...
@@ -177,8 +171,10 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
quant_config
.
set_force_pow_2_scales
(
my_quantizer_bw
->
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
my_quantizer_bw
->
amax_epsilon
);
}
NVTE_SCOPED_GIL_RELEASE
({
nvte_quantize_v2
(
unquantized_out_cu
.
data
(),
out_cu
.
data
(),
quant_config
,
at
::
cuda
::
getCurrentCUDAStream
());
});
}
return
{
out
,
py
::
cast
(
mu
),
py
::
cast
(
rsigma
)};
...
...
@@ -187,7 +183,6 @@ std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, Maybe
std
::
vector
<
py
::
object
>
rmsnorm_bwd
(
const
at
::
Tensor
&
dz
,
const
at
::
Tensor
&
x
,
const
at
::
Tensor
&
rsigma
,
const
at
::
Tensor
&
gamma
,
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
;
const
auto
&
dz_
=
dz
.
contiguous
();
const
auto
&
x_
=
x
.
contiguous
();
const
auto
&
rsigma_
=
rsigma
.
contiguous
();
...
...
@@ -195,7 +190,7 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto
dx
=
at
::
empty_like
(
x_
);
auto
dgamma
=
at
::
empty_like
(
gamma_
);
transformer_engine
::
TensorWrapper
workspace
;
TensorWrapper
workspace
;
auto
dz_cu
=
makeTransformerEngineTensor
(
dz_
);
auto
x_cu
=
makeTransformerEngineTensor
(
x_
);
...
...
@@ -205,10 +200,12 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
auto
dgamma_cu
=
makeTransformerEngineTensor
(
dgamma
);
// This call populates tensors with the required config.
NVTE_SCOPED_GIL_RELEASE
({
nvte_rmsnorm_bwd
(
dz_cu
.
data
(),
x_cu
.
data
(),
rsigma_cu
.
data
(),
gamma_cu
.
data
(),
dx_cu
.
data
(),
dgamma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Alloc space for Tensors.
auto
workspace_data
=
allocateSpace
(
workspace
.
shape
(),
workspace
.
dtype
());
...
...
@@ -216,21 +213,20 @@ std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
makeTransformerEngineTensor
(
workspace_data
.
data_ptr
(),
workspace
.
shape
(),
workspace
.
dtype
());
// Actual call to bwd kernel.
NVTE_SCOPED_GIL_RELEASE
({
nvte_rmsnorm_bwd
(
dz_cu
.
data
(),
x_cu
.
data
(),
rsigma_cu
.
data
(),
gamma_cu
.
data
(),
dx_cu
.
data
(),
dgamma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
return
{
py
::
cast
(
dx
),
py
::
cast
(
dgamma
)};
}
std
::
vector
<
py
::
object
>
rmsnorm_fwd
(
const
py
::
handle
&
input
,
const
py
::
handle
&
weight
,
float
eps
,
py
::
object
out
,
py
::
handle
quantizer
,
transformer_engine
::
DType
out_dtype
,
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
py
::
object
out
,
py
::
handle
quantizer
,
DType
out_dtype
,
const
int
sm_margin
,
const
bool
zero_centered_gamma
)
{
using
namespace
transformer_engine
::
pytorch
::
detail
;
using
namespace
transformer_engine
::
pytorch
;
using
namespace
transformer_engine
;
// Input and param tensors
auto
none
=
py
::
none
();
...
...
@@ -278,11 +274,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
TensorWrapper
&
kernel_out_cu
=
force_unfused_kernel
?
unquantized_out_cu
:
out_cu
;
// Query workspace size
transformer_engine
::
TensorWrapper
workspace
;
TensorWrapper
workspace
;
NVTE_SCOPED_GIL_RELEASE
({
nvte_rmsnorm_fwd
(
input_cu
.
data
(),
weight_cu
.
data
(),
eps
,
kernel_out_cu
.
data
(),
rsigma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Allocate workspace
auto
workspace_data
=
allocateSpace
(
workspace
.
shape
(),
workspace
.
dtype
());
...
...
@@ -290,10 +288,12 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
makeTransformerEngineTensor
(
workspace_data
.
data_ptr
(),
workspace
.
shape
(),
workspace
.
dtype
());
// Launch kernel
NVTE_SCOPED_GIL_RELEASE
({
nvte_rmsnorm_fwd
(
input_cu
.
data
(),
weight_cu
.
data
(),
eps
,
kernel_out_cu
.
data
(),
rsigma_cu
.
data
(),
workspace
.
data
(),
at
::
cuda
::
getCurrentDeviceProperties
()
->
multiProcessorCount
-
sm_margin
,
zero_centered_gamma
,
at
::
cuda
::
getCurrentCUDAStream
());
});
// Quantize output if using unfused kernel
if
(
force_unfused_kernel
)
{
...
...
@@ -301,7 +301,10 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
    if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) {
      // my_quantizer here has to be a Float8CurrentScalingQuantizer
      auto my_quantizer_cs = static_cast<Float8CurrentScalingQuantizer *>(my_quantizer.get());
-      nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(),
+                          at::cuda::getCurrentCUDAStream());
+      });
      // check if we need to do amax reduction (depending on model parallel configs)
      if (my_quantizer_cs->with_amax_reduction) {
        c10::intrusive_ptr<dist_group_type> process_group_ptr =
...
...
@@ -316,7 +319,9 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
      }
      quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales);
      quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon);
-      nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream());
+      NVTE_SCOPED_GIL_RELEASE({
+        nvte_compute_scale_from_amax(out_cu.data(), quant_config,
+                                     at::cuda::getCurrentCUDAStream());
+      });
      // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel
      out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
    } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) {
...
...
@@ -324,9 +329,13 @@ std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &w
      quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales);
      quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon);
    }
-    nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
-                     at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config,
+                       at::cuda::getCurrentCUDAStream());
+    });
  }

  return {out, py::none(), py::cast(rsigma)};
}

}  // namespace transformer_engine::pytorch
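The current-scaling branch above follows a fixed order: compute the amax of the unquantized output, optionally all-reduce it across the model-parallel group, derive the scale from the amax, then quantize with the kernel's atomic amax updates disabled. A condensed sketch of that flow, using the calls named in the diff (the reduction step is elided; quant_config is the quantization-config wrapper whose exact type is not shown in this hunk):

// amax(|x|) of the high-precision output, written into the FP8 tensor's amax slot
nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), stream);
// ... optional amax all-reduce across the model-parallel process group ...
quant_config.set_force_pow_2_scales(force_pow_2_scales);
quant_config.set_amax_epsilon(amax_epsilon);
nvte_compute_scale_from_amax(out_cu.data(), quant_config, stream);  // scale = f(amax)
// Null the amax pointer so the quantize kernel skips atomic amax updates,
// then cast the high-precision output into the FP8 tensor.
out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape);
nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, stream);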
transformer_engine/pytorch/csrc/extensions/nvshmem_comm.cpp
...
@@ -17,7 +17,8 @@
 #include <torch/cuda.h>
 #include <torch/extension.h>

-namespace nvshmem_api {
+namespace transformer_engine::pytorch {
 void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
 #ifdef NVTE_ENABLE_NVSHMEM
   nvshmemx_init_attr_t attr = {};
...
@@ -126,4 +127,5 @@ void nvshmem_finalize() {
       "distributed process groups when TE is compiled with NVTE_ENABLE_NVSHMEM=1!");
 #endif
 }
-}  // namespace nvshmem_api
+}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/padding.cpp
...
@@ -5,13 +5,13 @@
 ************************************************************************/

 #include "extensions.h"
+#include "pybind.h"

+namespace transformer_engine::pytorch {

 void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                              std::vector<size_t> input_row_list,
                              std::vector<size_t> padded_input_row_list) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
   NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(),
              "Number of input row list and padded row list must match.");
   NVTE_CHECK(input.dim() == 2, "Dimension of input must equal 2.");
...
@@ -21,7 +21,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Extract properties from PyTorch tensors
   std::vector<void *> input_dptr_list, output_dptr_list;
   std::vector<std::vector<size_t>> input_shape_list, output_shape_list;
-  std::vector<transformer_engine::DType> input_type_list;
+  std::vector<DType> input_type_list;
   void *d_input_ptr = reinterpret_cast<void *>(input.data_ptr());
   void *d_output_ptr = reinterpret_cast<void *>(output.data_ptr());
   for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) {
...
@@ -51,9 +51,9 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
   // Construct TE tensors
   std::vector<NVTETensor> nvte_input_list, nvte_output_list;
-  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  std::vector<TensorWrapper> tensor_wrappers;
   auto make_tensor = [&tensor_wrappers](void *dptr, const std::vector<size_t> &shape,
-                                        transformer_engine::DType dtype) -> NVTETensor {
+                                        DType dtype) -> NVTETensor {
     tensor_wrappers.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
     return tensor_wrappers.back().data();
   };
...
@@ -75,6 +75,10 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
              "Number of input and padded row list must match");

   // Launch TE kernel
+  NVTE_SCOPED_GIL_RELEASE({
+    nvte_multi_padding(nvte_input_list.size(), nvte_input_list.data(), nvte_output_list.data(),
+                       padded_num_rows_list.data(), at::cuda::getCurrentCUDAStream());
+  });
 }

}  // namespace transformer_engine::pytorch
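The make_tensor lambda above solves a lifetime problem: an NVTETensor is a raw handle borrowed from a TensorWrapper, so the wrappers must outlive the kernel launch. A minimal sketch of the keep-alive pattern, in the spirit of the code above and assuming TE's wrapper types:

// Each TensorWrapper is stored in a vector the lambda captures by reference,
// so the raw NVTETensor handles stay valid until the multi-tensor kernel runs.
std::vector<NVTETensor> handles;
std::vector<TensorWrapper> keep_alive;  // owns the wrappers behind the handles
auto make_tensor = [&keep_alive](void *dptr, const std::vector<size_t> &shape,
                                 DType dtype) -> NVTETensor {
  keep_alive.emplace_back(makeTransformerEngineTensor(dptr, shape, dtype));
  return keep_alive.back().data();  // handle borrows from the wrapper just stored
};
// handles.push_back(make_tensor(ptr, {rows, cols}, DType::kBFloat16)); ... launch ...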
transformer_engine/pytorch/csrc/extensions/permutation.cu → transformer_engine/pytorch/csrc/extensions/permutation.cpp
...
@@ -4,14 +4,13 @@
 * See LICENSE for license information.
 ************************************************************************/

-#include <cub/cub.cuh>
 #include "extensions.h"

+namespace transformer_engine::pytorch {

 std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
-    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
-    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
-  using namespace transformer_engine::pytorch;
+    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
+    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num) {
   const int num_tokens = input.size(0);
   int num_cols = input.size(1);
   const int topK = indices.size(1);
...
@@ -28,9 +27,8 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
                     torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));

    size_t temp_storage_bytes = 0;
-    int *temp_ptr = nullptr;
-    cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_ptr, temp_ptr, temp_ptr,
-                                    temp_ptr, max_expanded_token_num);
+    nvte_device_radix_sort_pairs(nullptr, &temp_storage_bytes, nullptr, nullptr, nullptr, nullptr,
+                                 max_expanded_token_num);
    at::Tensor temp_storage = torch::empty(
        temp_storage_bytes, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
...
@@ -40,17 +38,18 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
    workspace.push_back(temp_storage);
  }

-  int *indices_ptr = reinterpret_cast<int *>(getDataPtr(indices, 0));
-  int *sorted_indices_ptr = reinterpret_cast<int *>(getDataPtr(workspace[0], 0));
-  int *row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[1], 0));
-  int *sorted_row_id_ptr = reinterpret_cast<int *>(getDataPtr(workspace[2], 0));
+  void *indices_ptr = getDataPtr(indices, 0);
+  void *sorted_indices_ptr = getDataPtr(workspace[0], 0);
+  void *row_id_ptr = getDataPtr(workspace[1], 0);
+  void *sorted_row_id_ptr = getDataPtr(workspace[2], 0);
  void *d_temp_storage = getDataPtr(workspace[3], 0);

  size_t temp_storage_bytes = std::numeric_limits<size_t>::max();
-  cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, indices_ptr,
-                                  sorted_indices_ptr, row_id_ptr, sorted_row_id_ptr,
-                                  num_tokens * topK);
+  nvte_device_radix_sort_pairs(d_temp_storage, &temp_storage_bytes,
+                               reinterpret_cast<int *>(indices_ptr),
+                               reinterpret_cast<int *>(sorted_indices_ptr),
+                               reinterpret_cast<int *>(row_id_ptr),
+                               reinterpret_cast<int *>(sorted_row_id_ptr), num_tokens * topK);

  // Output buffer alloc
  num_out_tokens = (num_out_tokens > 0) ? num_out_tokens : num_tokens * topK;
...
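The new nvte_device_radix_sort_pairs wrapper preserves CUB's two-phase contract: a null temp-storage pointer turns the call into a pure size query, and the second call with a real buffer performs the sort. A short sketch (the key/value pointers are hypothetical int* device buffers):

// Phase 1: size query. With a null temp-storage pointer, only
// temp_bytes is written; nothing is sorted.
size_t temp_bytes = 0;
nvte_device_radix_sort_pairs(nullptr, &temp_bytes, nullptr, nullptr, nullptr, nullptr,
                             max_items);

// Allocate scratch space on the GPU, then run the real sort.
at::Tensor temp = torch::empty(temp_bytes,
                               torch::dtype(torch::kInt8).device(torch::kCUDA));
nvte_device_radix_sort_pairs(temp.data_ptr(), &temp_bytes, keys_in, keys_out,
                             vals_in, vals_out, num_items);  // phase 2: sort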
@@ -63,34 +62,33 @@ std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  auto input_cu = makeTransformerEngineTensor(
-      input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
-      dtype);
+      input.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
  auto permuted_output_cu = makeTransformerEngineTensor(
-      permuted_output.data_ptr(),
-      {static_cast<size_t>(permuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      permuted_output.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(permuted_output.size(0)),
+                          static_cast<size_t>(num_cols)},
+      dtype);
-  auto sorted_row_id_cu = makeTransformerEngineTensor(
-      sorted_row_id_ptr, {static_cast<size_t>(num_tokens * topK)},
-      transformer_engine::DType::kInt32);
+  auto sorted_row_id_cu = makeTransformerEngineTensor(
+      sorted_row_id_ptr, std::vector<size_t>{static_cast<size_t>(num_tokens * topK)},
+      DType::kInt32);
  auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);

  nvte_permute(input_cu.data(), permuted_output_cu.data(), sorted_row_id_cu.data(),
-               row_id_map_cu.data(), transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(),
-               transformer_engine::TensorWrapper().data(), num_tokens, topK, num_cols,
-               num_out_tokens, stream);
+               row_id_map_cu.data(), TensorWrapper().data(), TensorWrapper().data(),
+               TensorWrapper().data(), num_tokens, topK, num_cols, num_out_tokens, stream);

  return std::make_tuple(permuted_output, row_id_map, workspace);
}

-at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
-                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                           int64_t topK) {
+at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                           at::Tensor prob, int64_t num_tokens, int64_t topK) {
  return moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK);
}
-at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
-                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
-                             int64_t topK) {
-  using namespace transformer_engine::pytorch;
+at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
+                             at::Tensor prob, int64_t num_tokens, int64_t topK) {
  int num_cols = input.size(1);

  // Output buffer alloc
...
@@ -101,10 +99,14 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  auto input_cu = makeTransformerEngineTensor(
-      input.data_ptr(), {static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
-      dtype);
+      input.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
  auto unpermuted_output_cu = makeTransformerEngineTensor(
-      unpermuted_output.data_ptr(),
-      {static_cast<size_t>(unpermuted_output.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      unpermuted_output.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(unpermuted_output.size(0)),
+                          static_cast<size_t>(num_cols)},
+      dtype);
  auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
  auto prob_cu = makeTransformerEngineTensor(prob);
...
@@ -115,9 +117,8 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d
}

std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
-                                                     const transformer_engine::DType dtype,
-                                                     at::Tensor row_id_map, at::Tensor prob) {
-  using namespace transformer_engine::pytorch;
+                                                     const DType dtype, at::Tensor row_id_map,
+                                                     at::Tensor prob) {
  const int topK = (prob.numel() > 0) ? prob.size(1) : 1;
  const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0);
  int num_cols = input_bwd.size(1);
...
@@ -132,21 +133,26 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  auto input_bwd_cu = makeTransformerEngineTensor(
-      input_bwd.data_ptr(),
-      {static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      input_bwd.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input_bwd.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
  auto act_grad_cu = makeTransformerEngineTensor(
-      act_grad.data_ptr(),
-      {static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      act_grad.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(act_grad.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
  auto input_fwd_cu = makeTransformerEngineTensor(
-      input_fwd.data_ptr(),
-      {static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)}, dtype);
+      input_fwd.data_ptr(),
+      std::vector<size_t>{static_cast<size_t>(input_fwd.size(0)), static_cast<size_t>(num_cols)},
+      dtype);
  auto row_id_map_cu = makeTransformerEngineTensor(row_id_map);
  auto prob_cu = makeTransformerEngineTensor(prob);
  auto prob_grad_cu = makeTransformerEngineTensor(prob_grad);

-  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), transformer_engine::TensorWrapper().data(),
+  nvte_permute(input_bwd_cu.data(), act_grad_cu.data(), TensorWrapper().data(),
               row_id_map_cu.data(), prob_cu.data(), prob_grad_cu.data(), input_fwd_cu.data(),
               num_tokens, topK, num_cols, 0, stream);

  return std::make_tuple(act_grad, prob_grad);
}

}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/pybind.cpp
...
@@ -6,7 +6,6 @@
 #include "pybind.h"

 #include <Python.h>
 #include <pybind11/cast.h>
-#include <pybind11/detail/common.h>
 #include <pybind11/functional.h>
...
@@ -111,10 +110,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"),
        py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
        py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false);
-  m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.",
-        py::call_guard<py::gil_scoped_release>());
  m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
        py::arg("quantizer"));
  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
...
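The GEMM binding above uses pybind11's optional-argument idiom: py::arg(...) = std::nullopt surfaces a C++ std::optional parameter as a Python keyword argument that defaults to None. A minimal self-contained sketch (module and function names are hypothetical):

#include <optional>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // enables std::optional <-> None conversion

namespace py = pybind11;

PYBIND11_MODULE(optional_demo, m) {
  // comm_type-style parameter: None on the Python side arrives as nullopt here.
  m.def(
      "scale",
      [](double x, std::optional<double> factor) {
        return x * factor.value_or(1.0);  // fallback behavior when None is passed
      },
      py::arg("x"), py::arg("factor") = std::nullopt);
}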
@@ -160,85 +155,111 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::arg("quantizer"));

  // Permutation functions
-  m.def("moe_permute_fwd", moe_permute_fwd);
-  m.def("moe_permute_bwd", moe_permute_bwd);
-  m.def("moe_unpermute_fwd", moe_unpermute_fwd);
-  m.def("moe_unpermute_bwd", moe_unpermute_bwd);
+  m.def("moe_permute_fwd", transformer_engine::pytorch::moe_permute_fwd, "MOE permute FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_permute_bwd", transformer_engine::pytorch::moe_permute_bwd, "MOE permute BWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_unpermute_fwd", transformer_engine::pytorch::moe_unpermute_fwd, "MOE unpermute FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("moe_unpermute_bwd", transformer_engine::pytorch::moe_unpermute_bwd, "MOE unpermute BWD",
+        py::call_guard<py::gil_scoped_release>());

  // Softmax functions
-  m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
-        "Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
-        "Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_forward", &scaled_upper_triang_masked_softmax_forward,
-        "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_upper_triang_masked_softmax_backward",
-        &scaled_upper_triang_masked_softmax_backward,
-        "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_aligned_causal_masked_softmax_forward",
-        &scaled_aligned_causal_masked_softmax_forward,
-        "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("scaled_aligned_causal_masked_softmax_backward",
-        &scaled_aligned_causal_masked_softmax_backward,
-        "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_softmax_forward", &transformer_engine::pytorch::scaled_softmax_forward,
+        "Scaled Softmax FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_softmax_backward", &transformer_engine::pytorch::scaled_softmax_backward,
+        "Scaled Softmax BWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_masked_softmax_forward,
+        "Scaled Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_masked_softmax_backward,
+        "Scaled Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_upper_triang_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_forward,
+        "Scaled Upper-Triangular Masked Softmax FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_upper_triang_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_upper_triang_masked_softmax_backward,
+        "Scaled Upper-Triangular Masked Softmax BWD", py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_aligned_causal_masked_softmax_forward",
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_forward,
+        "Scaled Bottom-Right Corner Aligned Masked Softmax FWD",
+        py::call_guard<py::gil_scoped_release>());
+  m.def("scaled_aligned_causal_masked_softmax_backward",
+        &transformer_engine::pytorch::scaled_aligned_causal_masked_softmax_backward,
+        "Scaled Bottom-Right Corner Aligned Masked Softmax BWD",
+        py::call_guard<py::gil_scoped_release>());
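Nearly every registration above carries py::call_guard<py::gil_scoped_release>, pybind11's declarative way to drop the GIL for the duration of the bound call. A minimal self-contained sketch with hypothetical names:

#include <pybind11/pybind11.h>

namespace py = pybind11;

// Stand-in for a long-running kernel launch. It must not touch Python
// objects, because it runs without the GIL.
double busy_sum(long n) {
  double s = 0.0;
  for (long i = 0; i < n; ++i) s += static_cast<double>(i);
  return s;
}

PYBIND11_MODULE(guard_demo, m) {
  // The guard releases the GIL on entry and reacquires it on return, so
  // other Python threads keep running while busy_sum executes.
  m.def("busy_sum", &busy_sum, "GIL-free summation",
        py::call_guard<py::gil_scoped_release>());
}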
  // Other granular functions
-  m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"),
-        py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
-        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
-  m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm");
-  m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"),
-        py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"),
-        py::arg("sm_margin"), py::arg("zero_centered_gamma"));
-  m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm");
+  m.def("layernorm_fwd", &transformer_engine::pytorch::layernorm_fwd, "LayerNorm",
+        py::arg("input"), py::arg("weight"), py::arg("bias"), py::arg("eps"), py::arg("ln_out"),
+        py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"),
+        py::arg("zero_centered_gamma"));
+  m.def("layernorm_bwd", &transformer_engine::pytorch::layernorm_bwd, "Backward of LayerNorm");
+  m.def("rmsnorm_fwd", &transformer_engine::pytorch::rmsnorm_fwd, "RMSNorm", py::arg("input"),
+        py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"),
+        py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"));
+  m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm");
  m.def("fused_multi_quantize", &transformer_engine::pytorch::fused_multi_quantize,
        "Fused Multi-tensor Cast + Transpose", py::arg("input_list"), py::arg("output_list"),
        py::arg("quantizer_list"), py::arg("otype"));
-  m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
+  m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm,
+        "Grouped GEMM");
#ifdef USE_ROCM
-  m.def("te_batchgemm_ts", &te_batchgemm_ts, "Batched GEMM");  /// rocblas
+  m.def("te_batchgemm_ts", &transformer_engine::pytorch::te_batchgemm_ts,
+        "Batched GEMM");  /// rocblas
#endif
  m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O",
        py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"),
        py::call_guard<py::gil_scoped_release>());
-  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
+        "Get Fused Attention backend", py::call_guard<py::gil_scoped_release>());
-  m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
+  m.def("compute_amax", &transformer_engine::pytorch::compute_amax,
+        "Compute absolute max value in tensor", py::arg("input"), py::arg("amax"),
+        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_amax_and_scale_update_after_reduction",
-        &fused_amax_and_scale_update_after_reduction,
+  m.def("fused_amax_and_scale_update_after_reduction",
+        &transformer_engine::pytorch::fused_amax_and_scale_update_after_reduction,
        "Update amax history and FP8 scale/scale_inv after reduction",
        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_block_scaling_compute_partial_amax",
+        &transformer_engine::pytorch::fp8_block_scaling_compute_partial_amax,
+        "Compute partial amax from master weights for fp8 block scaling", py::arg("tensor"),
+        py::arg("amax"), py::arg("h"), py::arg("w"), py::arg("start_offset"),
+        py::arg("block_len"), py::call_guard<py::gil_scoped_release>());
+  m.def("fp8_block_scaling_partial_cast",
+        &transformer_engine::pytorch::fp8_block_scaling_partial_cast,
+        "Partial cast from master weights for fp8 block scaling", py::arg("inp"), py::arg("out"),
+        py::arg("scale"), py::arg("h"), py::arg("w"), py::arg("start_offset"),
+        py::arg("block_len"), py::arg("out_dtype"), py::call_guard<py::gil_scoped_release>());
+  m.def("fused_multi_row_padding", &transformer_engine::pytorch::fused_multi_row_padding,
+        "Fused Multi-tensor padding", py::call_guard<py::gil_scoped_release>());
  // attention kernels
-  m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_fwd", &transformer_engine::pytorch::fa_prepare_fwd,
+        "Prepare QKV for Flash Attention", py::call_guard<py::gil_scoped_release>());
+  m.def("fa_prepare_bwd", &transformer_engine::pytorch::fa_prepare_bwd,
+        "Backward of QKV preparation for Flash Attention",
+        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_attn_fwd", &fused_attn_fwd,
-        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
+  m.def("fused_attn_fwd", &transformer_engine::pytorch::fused_attn_fwd,
+        "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
-  m.def("fused_attn_bwd", &fused_attn_bwd,
-        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
+  m.def("fused_attn_bwd", &transformer_engine::pytorch::fused_attn_bwd,
+        "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V");
-  m.def("copy_to_kv_cache", &copy_to_kv_cache, "Copy new KV tokens to KV cache");
-  m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD");
-  m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tensor from BSHD to THD");
+  m.def("copy_to_kv_cache", &transformer_engine::pytorch::copy_to_kv_cache,
+        "Copy new KV tokens to KV cache", py::call_guard<py::gil_scoped_release>());
+  m.def("convert_thd_to_bshd", &transformer_engine::pytorch::convert_thd_to_bshd,
+        "Convert a tensor from THD to BSHD", py::call_guard<py::gil_scoped_release>());
+  m.def("convert_bshd_to_thd", &transformer_engine::pytorch::convert_bshd_to_thd,
+        "Convert a tensor from BSHD to THD", py::call_guard<py::gil_scoped_release>());
  // fused apply rope
-  m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_forward", &transformer_engine::pytorch::fused_rope_forward,
+        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
+  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
+        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());

  // Misc
-  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
-        py::call_guard<py::gil_scoped_release>());
+  m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
+        "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
+  m.def("get_cudnn_version", &transformer_engine::pytorch::get_cudnn_version,
+        "Get cuDNN version", py::call_guard<py::gil_scoped_release>());
  m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
#ifdef USE_ROCM
...
@@ -246,74 +267,82 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif

  // Support THD format for Context Parallel
-  m.def("thd_read_half_tensor", &thd_read_half_tensor,
+  m.def("thd_read_half_tensor", &transformer_engine::pytorch::thd_read_half_tensor,
        "Read the first half(half_idx=0) or the second half(half_idx=1) of each sequence in a THD "
        "tensor", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_second_half_lse_correction", &thd_second_half_lse_correction,
+  m.def("thd_second_half_lse_correction",
+        &transformer_engine::pytorch::thd_second_half_lse_correction,
        "Correct the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_read_second_half_lse", &thd_read_second_half_lse,
+  m.def("thd_read_second_half_lse", &transformer_engine::pytorch::thd_read_second_half_lse,
        "Read the second half of the softmax_lse", py::call_guard<py::gil_scoped_release>());
-  m.def("thd_out_correction", &thd_out_correction,
+  m.def("thd_out_correction", &transformer_engine::pytorch::thd_out_correction,
        "Correct the THD format output of context parallelism in forward pass",
        py::call_guard<py::gil_scoped_release>());
-  m.def("thd_grad_correction", &thd_grad_correction,
+  m.def("thd_grad_correction", &transformer_engine::pytorch::thd_grad_correction,
        "Correct the THD format gradients of context parallelism in backward pass",
        py::call_guard<py::gil_scoped_release>());
-  m.def("thd_get_partitioned_indices", &thd_get_partitioned_indices,
+  m.def("thd_get_partitioned_indices",
+        &transformer_engine::pytorch::thd_get_partitioned_indices,
        "Generate partitioned indices for inputs in THD format",
        py::call_guard<py::gil_scoped_release>());

  // nvshmem functions
-  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
+  m.def("init_nvshmem_backend", &transformer_engine::pytorch::init_nvshmem_backend,
        "Initialize nvshmem backend with Pytorch distributed process groups",
        py::call_guard<py::gil_scoped_release>());
-  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
+  m.def("create_nvshmem_tensor", &transformer_engine::pytorch::create_nvshmem_tensor,
        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
+  m.def("nvshmem_send_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_send_on_current_stream,
        "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
        py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
+  m.def("nvshmem_wait_on_current_stream",
+        &transformer_engine::pytorch::nvshmem_wait_on_current_stream,
        "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
        "stream", py::call_guard<py::gil_scoped_release>());
-  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
+  m.def("nvshmem_finalize", &transformer_engine::pytorch::nvshmem_finalize,
        "Clean up and finalize the NVSHMEM communication backend and free associated resources",
        py::call_guard<py::gil_scoped_release>());

  // multi-tensor functions
-  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
+  m.def("multi_tensor_scale", &transformer_engine::pytorch::multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+  m.def("multi_tensor_l2norm", &transformer_engine::pytorch::multi_tensor_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_unscale_l2norm", &multi_tensor_unscale_l2norm_cuda,
+  m.def("multi_tensor_unscale_l2norm",
+        &transformer_engine::pytorch::multi_tensor_unscale_l2norm_cuda,
        "Computes L2 norm for a list of contiguous tensors after unscaling (unscaling is only "
        "performed for L2 norm computation, and tensors are not updated)",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam", &multi_tensor_adam_cuda,
+  m.def("multi_tensor_adam", &transformer_engine::pytorch::multi_tensor_adam_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda,
+  m.def("multi_tensor_adam_param_remainder",
+        &transformer_engine::pytorch::multi_tensor_adam_param_remainder_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer"
        "where the master parameters only store the remainder bits",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda,
+  m.def("multi_tensor_adam_fp8", &transformer_engine::pytorch::multi_tensor_adam_fp8_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable", &multi_tensor_adam_capturable_cuda,
+  m.def("multi_tensor_adam_capturable",
+        &transformer_engine::pytorch::multi_tensor_adam_capturable_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support and LR scheduling", py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_adam_capturable_master", &multi_tensor_adam_capturable_master_cuda,
+  m.def("multi_tensor_adam_capturable_master",
+        &transformer_engine::pytorch::multi_tensor_adam_capturable_master_cuda,
        "Compute and apply gradient update to parameters for Adam optimizer with CUDA graph "
        "support, LR scheduling and FP32 master weights",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
+  m.def("multi_tensor_sgd", &transformer_engine::pytorch::multi_tensor_sgd_cuda,
        "Fused SGD optimizer for list of contiguous tensors",
        py::call_guard<py::gil_scoped_release>());
-  m.def("multi_tensor_compute_scale_and_scale_inv",
-        &multi_tensor_compute_scale_and_scale_inv_cuda,
+  m.def("multi_tensor_compute_scale_and_scale_inv",
+        &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda,
        "Fused compute scale and scale_inv from amax",
        py::call_guard<py::gil_scoped_release>());

  // Data structures
...
@@ -359,10 +388,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true,
           py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false)
      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlap::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);

  py::class_<CommOverlapP2P, std::shared_ptr<CommOverlapP2P>,
             transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>(
...
@@ -377,8 +405,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false,
           py::arg("use_ce") = true, py::arg("aggregate") = false)
      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("quantizer"), py::arg("local_chunk") = false)
-      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"),
-           py::arg("local_chunk") = false, py::arg("shape") = std::nullopt)
-      .def("set_buffer_params", &CommOverlapP2P::set_buffer_params);
+           py::arg("local_chunk") = false)
+      .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
+           py::arg("shape") = std::nullopt);
}
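The overlap classes above are bound with std::shared_ptr as the holder type, so Python and C++ can share ownership of the same object. A minimal self-contained sketch of that binding style (all names hypothetical, not the TE classes themselves):

#include <memory>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical stand-in for the overlap classes bound above.
struct Overlap {
  explicit Overlap(int num_sm) : num_sm_(num_sm) {}
  int get_buffer(bool local_chunk) const { return local_chunk ? num_sm_ / 2 : num_sm_; }
  int num_sm_;
};

PYBIND11_MODULE(overlap_demo, m) {
  // shared_ptr holder: the Python object keeps the C++ instance alive even if
  // native code also holds a reference; keyword defaults mirror the diff.
  py::class_<Overlap, std::shared_ptr<Overlap>>(m, "Overlap")
      .def(py::init<int>(), py::arg("num_comm_sm") = 16)
      .def("get_buffer", &Overlap::get_buffer, py::arg("local_chunk") = false);
}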
transformer_engine/pytorch/csrc/extensions/recipe.cpp
...
@@ -12,10 +12,9 @@
 #include "common/common.h"
 #include "extensions.h"

-void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+namespace transformer_engine::pytorch {

+void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
  auto input_tensor = tensor.contiguous();
  const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
...
@@ -23,7 +22,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
  TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
  TensorWrapper fake_te_output(
-      nullptr, te_input.shape(),
-      transformer_engine::DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      nullptr, te_input.shape(),
+      DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
      amax.data_ptr<float>());

  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
...
@@ -33,10 +32,7 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 const std::string& amax_compute_algo,
-                                                 transformer_engine::DType fp8_dtype, float margin) {
-  using namespace transformer_engine;
-  using namespace transformer_engine::pytorch;
+                                                 DType fp8_dtype, float margin) {
  size_t num_tensors = amax_histories.size();
  std::vector<Tensor> t_amax_histories(num_tensors);
  std::vector<Tensor> t_scales(num_tensors);
...
@@ -63,3 +59,5 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reductio
                                              amax_compute_algo.c_str(),
                                              static_cast<NVTEDType>(fp8_dtype), margin,
                                              at::cuda::getCurrentCUDAStream());
}

+}  // namespace transformer_engine::pytorch
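compute_amax works by handing the reduction kernel a "fake" output tensor: no data buffer, only the amax field points at the caller's one-element float tensor, so the kernel writes the amax and nothing else. Condensed from the hunk above, same names:

// An output TensorWrapper with a null data pointer but a live amax pointer
// makes nvte_compute_amax produce only the amax value.
auto input_tensor = tensor.contiguous();
const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
TensorWrapper fake_te_output(
    nullptr, te_input.shape(),
    DType::kFloat8E4M3,  // dtype is irrelevant here; only amax is computed
    amax.data_ptr<float>());
nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());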
transformer_engine/pytorch/csrc/extensions/softmax.cpp
...
@@ -6,8 +6,9 @@

 #include "extensions.h"

+namespace transformer_engine::pytorch {

 at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
...
@@ -38,8 +39,6 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) {
 at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                    float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
...
@@ -65,8 +64,6 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r
 }

 at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
...
@@ -105,8 +102,6 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa
 at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                           float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
...
@@ -132,8 +127,6 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so
 }

 at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
...
@@ -159,8 +152,6 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc
 at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                        at::Tensor softmax_results_,
                                                        float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   auto output_grads = output_grads_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
...
@@ -188,7 +179,6 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
 }

 at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) {
-  using namespace transformer_engine::pytorch;
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
   AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                  (input.scalar_type() == at::ScalarType::BFloat16),
...
@@ -220,8 +210,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float
 at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_,
                                                          at::Tensor softmax_results_,
                                                          float scale_factor) {
-  using namespace transformer_engine::pytorch;
-
   auto output_grads = output_grad_.contiguous();
   auto softmax_results = softmax_results_.contiguous();
...
@@ -245,3 +233,5 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_
   return output_grads;
 }

+}  // namespace transformer_engine::pytorch
transformer_engine/pytorch/csrc/extensions/transpose.cpp
...
@@ -13,13 +13,12 @@ namespace transformer_engine::pytorch {
 std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
                                              std::optional<std::vector<py::object>> output_list,
-                                             std::vector<py::handle> quantizer_list,
-                                             transformer_engine::DType otype) {
+                                             std::vector<py::handle> quantizer_list,
+                                             DType otype) {
   init_extension();
   std::vector<NVTETensor> nvte_tensor_input_list;
   std::vector<NVTETensor> nvte_tensor_output_list;
   std::vector<py::object> py_output_objects_list;
-  std::vector<transformer_engine::TensorWrapper> tensor_wrappers;
+  std::vector<TensorWrapper> tensor_wrappers;
   if (output_list.has_value()) {
     py_output_objects_list = output_list.value();
   }
...
@@ -33,7 +32,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
     auto input_tensor = makeTransformerEngineTensor(input_list[i]);
     const NVTEShape input_shape = input_tensor.shape();

-    transformer_engine::TensorWrapper output_tensor;
+    TensorWrapper output_tensor;

     if (!detail::IsFloat8Quantizers(quantizer_list[i].ptr())) {
       with_fused_kernel = false;
...
@@ -68,8 +67,10 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
   // Launch TE kernel
   if (with_fused_kernel) {
-    nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
-                              nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
+    NVTE_SCOPED_GIL_RELEASE({
+      nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(),
+                                nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream());
+    });
   } else {
     for (size_t i = 0; i < py_output_objects_list.size(); i++) {
       quantize(input_list[i], quantizer_list[i], py_output_objects_list[i], std::nullopt);
...
@@ -78,8 +79,7 @@ std::vector<py::object> fused_multi_quantize(std::vector<at::Tensor> input_list,
   return py_output_objects_list;
 }

-at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
-                         std::optional<at::Tensor> output) {
+at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional<at::Tensor> output) {
   init_extension();

   const auto dim = input.dim();
...
@@ -100,8 +100,8 @@ at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
   }
   if (M == 0 || N == 0) return out;

-  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype);
-  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype);
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector<size_t>{M, N}, otype);
+  auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector<size_t>{N, M}, otype);

   nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
...
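Throughout this commit, brace-list shapes such as {M, N} are rewritten as explicit std::vector<size_t>{M, N}. The commit does not state why; a plausible reason (an assumption, not confirmed by the diff) is overload disambiguation: makeTransformerEngineTensor has overloads taking a shape vector and an NVTEShape-style aggregate, and a bare braced list can match both. A self-contained illustration of that ambiguity:

#include <cstddef>
#include <vector>

struct Shape { size_t data[8]; size_t ndim; };  // stand-in for an NVTEShape-like aggregate

int make_tensor(void *, const std::vector<size_t> &) { return 1; }
int make_tensor(void *, const Shape &) { return 2; }

int demo(void *ptr, size_t M, size_t N) {
  // make_tensor(ptr, {M, N});  // error: {M, N} can build either argument type
  return make_tensor(ptr, std::vector<size_t>{M, N});  // names the overload explicitly
}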
transformer_engine/pytorch/csrc/pybind.h
...
@@ -8,6 +8,8 @@
 #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_
 #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_

+#include <Python.h>
+#include <pybind11/detail/common.h>
 #include <pybind11/functional.h>
 #include <pybind11/pybind11.h>
...
@@ -18,6 +20,16 @@
 namespace transformer_engine::pytorch {

+#define NVTE_SCOPED_GIL_RELEASE(code_block)           \
+  do {                                                \
+    if (PyGILState_Check()) {                         \
+      pybind11::gil_scoped_release _gil_release;      \
+      code_block                                      \
+    } else {                                          \
+      code_block                                      \
+    }                                                 \
+  } while (false);
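This macro is the mechanism behind all the NVTE_SCOPED_GIL_RELEASE call sites in this commit: PyGILState_Check() guards the release, so the wrapped block is safe whether or not the calling thread holds the GIL, whereas a bare pybind11::gil_scoped_release is only valid on a thread that actually holds it. A short usage sketch (the wrapper function is hypothetical):

// The same code path can be reached from Python (GIL held) or from a
// GIL-free C++ worker thread; the macro handles both cases.
void launch_amax(const TensorWrapper &in, TensorWrapper &out) {
  NVTE_SCOPED_GIL_RELEASE({
    nvte_compute_amax(in.data(), out.data(), at::cuda::getCurrentCUDAStream());
  });
}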
 extern PyTypeObject *Float8TensorPythonClass;
 extern PyTypeObject *Float8TensorBasePythonClass;
 extern PyTypeObject *Float8QuantizerClass;
...
transformer_engine/pytorch/csrc/extensions/quantizer.cpp → transformer_engine/pytorch/csrc/quantizer.cpp
...
@@ -9,7 +9,6 @@
 #include "common.h"
 #include "pybind.h"
 #include "torch/torch.h"
-#include "util.h"

 namespace transformer_engine::pytorch {
...
@@ -103,7 +102,7 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   }
   const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
   at::Tensor columnwise_data;
-  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
+  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (create_transpose) {
     columnwise_data = at::empty(columnwise_torch_shape, opts);
   }
...
@@ -215,7 +214,7 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   }
   const py::object py_data = rowwise_usage ? py::cast(data) : py::none();
   at::Tensor columnwise_data;
-  bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported();
+  bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (create_transpose) {
     columnwise_data = at::empty(columnwise_torch_shape, opts);
   }
...