OpenDAS / TransformerEngine / Commits

Commit 44740c6c, authored Jul 22, 2025 by yuguo
Merge commit '7a9a0825' of https://github.com/NVIDIA/TransformerEngine
Parents: 8113d9e0, 7a9a0825
Changes: 162
Showing 20 changed files with 1851 additions and 250 deletions (+1851, -250).
transformer_engine/common/transformer_engine.cpp               +2     -0
transformer_engine/common/util/cast_gated_kernels.cuh          +13    -10
transformer_engine/common/util/cuda_runtime.cpp                +14    -1
transformer_engine/common/util/cuda_runtime.h                  +6     -0
transformer_engine/common/util/multi_stream.cpp                +8     -0
transformer_engine/common/util/padding.cu                      +163   -0
transformer_engine/common/util/rtc.cpp                         +1     -1
transformer_engine/debug/features/utils/stats_buffer.py        +7     -0
transformer_engine/debug/features/utils/stats_computation.py   +3     -1
transformer_engine/debug/pytorch/debug_quantization.py         +6     -0
transformer_engine/jax/cpp_extensions/activation.py            +68    -50
transformer_engine/jax/cpp_extensions/gemm.py                  +1094  -88
transformer_engine/jax/cpp_extensions/misc.py                  +7     -2
transformer_engine/jax/cpp_extensions/normalization.py         +17    -6
transformer_engine/jax/cpp_extensions/quantization.py          +49    -36
transformer_engine/jax/csrc/extensions.h                       +3     -0
transformer_engine/jax/csrc/extensions/ffi.cpp                 +3     -4
transformer_engine/jax/csrc/extensions/gemm.cpp                +372   -51
transformer_engine/jax/csrc/extensions/misc.h                  +9     -0
transformer_engine/jax/csrc/extensions/pybind.cpp              +6     -0
transformer_engine/common/transformer_engine.cpp

@@ -46,6 +46,8 @@ std::string to_string(const DType type) {
      return "Float8E8M0";
+    case DType::kFloat4E2M1:
+      return "Float4E2M1";
    case DType::kInt16:
      return "Int16";
    case DType::kInt32:
      return "Int32";
    case DType::kInt64:
transformer_engine/common/util/cast_gated_kernels.cuh

@@ -936,17 +936,20 @@ template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "gated_act_input");
  CheckOutputTensor(*output, "gated_act_output");
-  NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
-  NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions.");
-  NVTE_CHECK(input.data.shape[0] == output->data.shape[0],
-             "Input shape[0] must be equal to output shape[0].");
-  NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
-             "Input shape[1] must be 2x larger than output shape[1].");
+  NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
+             "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
+             ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
+  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
+             "Wrong input shape. Expected (after flattening) last dimension to be even, ",
+             "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
+  NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2,
+             "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2,
+             "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-      input.data.dtype, IType,
+      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
-          output->data.dtype, OType,
+          output->dtype(), OType,
          if (!is_fp8_dtype(output->data.dtype) ||
              is_delayed_tensor_scaling(output->scaling_mode)) {

@@ -956,8 +959,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const fp32 *>(output->scale.dptr),
              reinterpret_cast<fp32 *>(output->amax.dptr),
-              reinterpret_cast<fp32 *>(output->scale_inv.dptr),
-              output->data.shape[0], output->data.shape[1], {}, stream);
+              reinterpret_cast<fp32 *>(output->scale_inv.dptr),
+              input.flat_first_dim(), output->flat_last_dim(), {}, stream);
          } else {
            NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
          });  // NOLINT(*)
transformer_engine/common/util/cuda_runtime.cpp

@@ -123,8 +123,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id)
bool supports_multicast(int device_id) {
#if CUDART_VERSION >= 12010
-  // NOTE: This needs to be guarded at compile time because the
+  // NOTE: This needs to be guarded at compile-time and run-time because the
  // CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED enum is not defined in earlier CUDA versions.
+  if (cudart_version() < 12010) {
+    return false;
+  }
  static std::vector<bool> cache(num_devices(), false);
  static std::vector<std::once_flag> flags(num_devices());
  if (device_id < 0) {

@@ -219,6 +222,16 @@ const std::string &include_directory(bool required) {
}
#endif  // __HIP_PLATFORM_AMD__

+int cudart_version() {
+  auto get_version = []() -> int {
+    int version;
+    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&version));
+    return version;
+  };
+  static int version = get_version();
+  return version;
+}

}  // namespace cuda

}  // namespace transformer_engine
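Editor's note on the run-time check added above: the integer returned by cudaRuntimeGetVersion() (and therefore by cudart_version()) encodes the version as major*1000 + minor*10, so the 12010 threshold corresponds to CUDA 12.1. A small illustrative sketch, not part of the commit:

version = 12010                       # value as reported by cudaRuntimeGetVersion()
major, minor = version // 1000, (version % 1000) // 10
print(major, minor)                   # -> 12 1, i.e. CUDA 12.1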
transformer_engine/common/util/cuda_runtime.h

@@ -79,6 +79,12 @@ bool supports_multicast(int device_id = -1);
const std::string &include_directory(bool required = false);
#endif

+/* \brief CUDA Runtime version number at run-time
+ *
+ * Versions may differ between compile-time and run-time.
+ */
+int cudart_version();

}  // namespace cuda

}  // namespace transformer_engine
transformer_engine/common/util/multi_stream.cpp

@@ -113,4 +113,12 @@ int get_num_compute_streams() {
int nvte_get_num_compute_streams() { return transformer_engine::detail::get_num_compute_streams(); }

+cudaStream_t nvte_get_compute_stream(const int idx) {
+  return transformer_engine::detail::get_compute_stream(idx);
+}
+
+cudaEvent_t nvte_get_compute_stream_event(const int idx) {
+  return transformer_engine::detail::get_compute_stream_event(idx);
+}

#endif  // TRANSFORMER_ENGINE_UTIL_MULTI_STREAM_H_
transformer_engine/common/util/padding.cu

@@ -126,6 +126,83 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
  }
}

template <int nvec, typename Type>
__global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(MultiPaddingArgs args) {
  using Vec = Vec<Type, nvec>;

  // Thread indices
  // Note: Block is interpreted as a warp_size x num_warps grid
  constexpr int bdimx = THREADS_PER_WARP;
  constexpr int bdimy = n_warps_per_tile;
  const int tid = threadIdx.x;
  const int tidx = tid % bdimx;
  const int tidy = tid / bdimx;
  const int bid = blockIdx.x;

  // Input tensors are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  constexpr int tile_dim_m = THREADS_PER_WARP * nvec;
  constexpr int tile_dim_n = THREADS_PER_WARP * nvec;

  // Number of nvec x nvec subtiles for each thread to load/store
  constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile;

  // Find tensor corresponding to block
  int tensor_id = 0;
  while (args.block_range[tensor_id + 1] <= bid) {
    ++tensor_id;
  }
  const Type *input = reinterpret_cast<const Type *>(args.input_list[tensor_id]);
  Type *output = reinterpret_cast<Type *>(args.output_list[tensor_id]);
  const int num_rows = args.num_rows_list[tensor_id];
  const int row_length = args.row_length_list[tensor_id];

  // Find position of tile within tensor
  const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
  const int tile_id = bid - args.block_range[tensor_id];
  const int tile_id_m = tile_id / num_tiles_n;
  const int tile_id_n = tile_id % num_tiles_n;
  const int tile_row = tile_id_m * tile_dim_m;
  const int tile_col = tile_id_n * tile_dim_n;

  // Load input and store to registers
  // Note: Each thread loads n_iterations subtiles, casts to output
  // type, and transposes in registers.
  Type local_zero = static_cast<Type>(0.f);
#pragma unroll
  for (int iter = 0; iter < n_iterations; ++iter) {
    const int i1 = tidy + iter * bdimy;
    const int j1 = tidx;
#pragma unroll
    for (int i2 = 0; i2 < nvec; ++i2) {
      const int row = tile_row + i1 * nvec + i2;
      const int col = tile_col + j1 * nvec;
      Vec local_input;
      Vec local_output;
      local_input.clear();
      if (row < num_rows) {
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            local_input.data.elt[j2] = input[row * row_length + col + j2];
          }
        }
      }
#pragma unroll
      for (int j2 = 0; j2 < nvec; ++j2) {
        local_output.data.elt[j2] = local_input.data.elt[j2];
      }
      if (row < num_rows) {
        for (int j2 = 0; j2 < nvec; ++j2) {
          if (col + j2 < row_length) {
            output[row * row_length + col + j2] = local_output.data.elt[j2];
          }
        }
      }
    }
  }
}

}  // namespace

void multi_padding(const std::vector<Tensor *> input_list, std::vector<Tensor *> output_list,

@@ -202,6 +279,78 @@ void multi_padding(const std::vector<Tensor*> input_list, std::vector<Tensor*> o
  }
}

void multi_unpadding(const std::vector<Tensor *> input_list, std::vector<Tensor *> output_list,
                     const std::vector<int> unpadded_num_rows_list, cudaStream_t stream) {
  // Check that number of tensors is valid
  NVTE_CHECK(output_list.size() == input_list.size(),
             "Number of input and output tensors must match");
  if (input_list.empty()) {
    return;
  }

  // Check that tensor properties are valid
  DType type = input_list[0]->data.dtype;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    const auto &input = *input_list[tensor_id];
    const auto &output = *output_list[tensor_id];
    CheckInputTensor(input, "multi_unpadding_input_" + std::to_string(tensor_id));
    CheckInputTensor(output, "multi_unpadding_output_" + std::to_string(tensor_id));
    NVTE_CHECK(input.data.dtype == type, "Input tensor types do not match.");
    NVTE_CHECK(output.data.dtype == type, "Output tensor types do not match.");
    NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions.");
    NVTE_CHECK(output.data.shape[0] == unpadded_num_rows_list[tensor_id],
               "output tensor shape does not match padded input shape.");
  }

  // Input matrices are divided into tiles
  // Note: Each tile is a warp_size x warp_size grid of nvec x nvec subtiles
  const int tile_dim_m = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);
  const int tile_dim_n = THREADS_PER_WARP * desired_load_store_size / typeToSize(type);

  // Add tensors to kernel argument struct
  MultiPaddingArgs kernel_args;
  kernel_args.num_tensors = 0;
  kernel_args.block_range[0] = 0;
  for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) {
    // Launch kernel if argument struct is full
    if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
          const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
          multi_unpadding_kernel<nvec, Type>
          <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
      kernel_args.num_tensors = 0;
    }

    // Calculate number of thread blocks needed for tensor
    const int num_rows = unpadded_num_rows_list[tensor_id];
    const int row_length = input_list[tensor_id]->data.shape[1];
    const int num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m;
    const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n;
    const int num_tiles = num_tiles_m * num_tiles_n;

    // Add tensor to kernel argument struct
    const int pos = kernel_args.num_tensors;
    kernel_args.input_list[pos] = const_cast<void *>(input_list[tensor_id]->data.dptr);
    kernel_args.output_list[pos] = output_list[tensor_id]->data.dptr;
    kernel_args.num_rows_list[pos] = num_rows;
    kernel_args.row_length_list[pos] = row_length;
    kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles;
    kernel_args.num_tensors++;
  }

  // Launch kernel
  if (kernel_args.num_tensors > 0) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
        type, Type, constexpr int nvec = desired_load_store_size / sizeof(Type);
        const int n_blocks = kernel_args.block_range[kernel_args.num_tensors];
        multi_unpadding_kernel<nvec, Type>
        <<<n_blocks, threads_per_block, 0, stream>>>(kernel_args););  // NOLINT(*)
  }
}

}  // namespace transformer_engine

void nvte_multi_padding(size_t num_tensors, const NVTETensor *input_list, NVTETensor *output_list,

@@ -217,3 +366,17 @@ void nvte_multi_padding(size_t num_tensors, const NVTETensor* input_list, NVTETe
  }
  multi_padding(input_list_, output_list_, padded_num_rows_list_, stream);
}

void nvte_multi_unpadding(size_t num_tensors, const NVTETensor *input_list,
                          NVTETensor *output_list, const int *unpadded_num_rows_list,
                          cudaStream_t stream) {
  NVTE_API_CALL(nvte_multi_unpadding);
  using namespace transformer_engine;
  std::vector<Tensor *> input_list_, output_list_;
  std::vector<int> unpadded_num_rows_list_;
  for (size_t i = 0; i < num_tensors; ++i) {
    input_list_.push_back(convertNVTETensorCheck(input_list[i]));
    output_list_.push_back(convertNVTETensorCheck(output_list[i]));
    unpadded_num_rows_list_.push_back(unpadded_num_rows_list[i]);
  }
  multi_unpadding(input_list_, output_list_, unpadded_num_rows_list_, stream);
}
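Editor's note: the per-tensor block count in multi_unpadding follows the same ceiling-division tiling as multi_padding. A hand-worked sketch with illustrative values only (the real tile_dim_m/tile_dim_n come from THREADS_PER_WARP and desired_load_store_size / typeToSize(type)):

num_rows, row_length = 1000, 512                           # unpadded rows and columns of one tensor
tile_dim_m = tile_dim_n = 64                               # assumed tile size for this example
num_tiles_m = (num_rows + tile_dim_m - 1) // tile_dim_m    # 16
num_tiles_n = (row_length + tile_dim_n - 1) // tile_dim_n  # 8
num_tiles = num_tiles_m * num_tiles_n                      # 128 thread blocks for this tensor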
transformer_engine/common/util/rtc.cpp

@@ -156,7 +156,7 @@ void KernelManager::compile(const std::string& kernel_label, const std::string&
#ifndef USE_ROCM
  const int sm_arch_ = cuda::sm_arch(device_id);
  const int compile_sm_arch = std::min(sm_arch_, max_supported_sm_arch());
-  const bool compile_ptx = (CUDA_VERSION <= 11000) || (sm_arch_ != compile_sm_arch);
+  const bool compile_ptx = sm_arch_ != compile_sm_arch;
#endif  // USE_ROCM

  // Compilation flags
transformer_engine/debug/features/utils/stats_buffer.py

@@ -85,6 +85,13 @@ class _Buffer:
        if self.modified[0] and not self.reduce_within_microbatch:
            return

+        if (
+            tensor.numel() == 0
+            if hasattr(tensor, "numel")
+            else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
+        ):
+            return
+
        # save stats for tensor to tmp buffer
        for stat_name in self.stats_to_compute:
            fn, _ = STATS[stat_name]
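Editor's note: a minimal sketch of the new empty-tensor guard above. Plain tensors expose .numel(), while wrapped tensors are probed through .get_data_tensors(). The FakeQuantizedTensor class below is a hypothetical stand-in for illustration, not a Transformer Engine type:

import torch

class FakeQuantizedTensor:                      # hypothetical stand-in, not a TE class
    def get_data_tensors(self):
        return [None, torch.empty(0)]

def is_empty(tensor):
    # Same conditional expression as the guard added in _Buffer above
    return (
        tensor.numel() == 0
        if hasattr(tensor, "numel")
        else all((t is None or t.numel() == 0) for t in tensor.get_data_tensors())
    )

print(is_empty(torch.empty(0)))                 # True: plain empty tensor
print(is_empty(FakeQuantizedTensor()))          # True: all data tensors empty or None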
transformer_engine/debug/features/utils/stats_computation.py

@@ -17,6 +17,8 @@ def _compute_dynamic_range_top(tensor):
    """Computes the log2 of the amax of the tensor"""
    tensor_abs = tensor.abs()
    tensor_abs = tensor_abs[tensor_abs != 0]
+    if tensor_abs.numel() == 0:
+        return torch.inf
    amax = tensor_abs.max().float()
    if not amax.all():
        amax = torch.tensor(1, device=tensor.device).to(torch.float)

@@ -125,7 +127,7 @@ STATS = {
        lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
    ),
    "underflows_num": (
-        lambda x: (x._data == 0).sum(),
+        lambda x: (x.get_data_tensors()[0] == 0).sum(),
        lambda buffers: sum(_get(buffers, "underflows_num")),
    ),
    "std": (
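Editor's note: restating the updated _compute_dynamic_range_top behaviour as a standalone sketch (the final log2 follows the function's docstring; the visible diff ends before the return). An all-zero tensor now short-circuits to +inf instead of calling max() on an empty tensor:

import torch

def dynamic_range_top(tensor):
    tensor_abs = tensor.abs()
    tensor_abs = tensor_abs[tensor_abs != 0]
    if tensor_abs.numel() == 0:       # all elements were zero
        return torch.inf
    amax = tensor_abs.max().float()
    if not amax.all():
        amax = torch.tensor(1, device=tensor.device).to(torch.float)
    return torch.log2(amax)

print(dynamic_range_top(torch.zeros(4)))             # inf
print(dynamic_range_top(torch.tensor([0.0, 8.0])))   # tensor(3.)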
transformer_engine/debug/pytorch/debug_quantization.py

@@ -62,6 +62,12 @@ class DebugQuantizer(Quantizer):
        self.tp_group = tp_group
        # used in inspect_tensor calls
        self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count

+        # .internal = True is slightly faster, but results
+        # in errors when caching the weights.
+        # Setting .internal = False is safer.
+        if parent_quantizer is not None:
+            parent_quantizer.internal = False
+
        self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

        # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
transformer_engine/jax/cpp_extensions/activation.py

@@ -415,37 +415,35 @@ class ActLuPrimitive(BasePrimitive):
        result_types,
    ):
        del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
+        prefix = "ActLuPrimitive_"

        x_rank = len(value_types[0].shape)
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
+            x_rank - 1, unique_var=prefix + "x", flatten_axis=-2
        )
-        x_axes = scale_rules.input_spec + (f"x{x_rank - 1}",)
+        x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",)
        out = (*x_axes[:-2], x_axes[-1])
        scale_inv = scale_rules.rowwise_rule
-        colwise_scale_inv = scale_rules.colwise_rule

+        colwise_out = (prefix + "out_colwise",)
+        colwise_scale_inv = (prefix + "scale_inv_colwise",)
        if is_2x:
+            colwise_scale_inv = scale_rules.colwise_rule
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(
                    multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
                )
            else:
                colwise_out = out
-        else:
-            colwise_out = ("j",)
-            colwise_scale_inv = ("k",)

        # amax is always a unit tensor.
-        amax = ("l",)
+        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, "…1", ("…1",)),
            (out, colwise_out, scale_inv, colwise_scale_inv, amax),
            **scale_rules.factor_sizes,
        )

@@ -890,28 +888,26 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
        result_types,
    ):
        del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
-        x_rank = len(value_types[1].shape)
+        prefix = "BaseDActLuDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
-            x_rank, unique_var="BaseDActLuDBiasQuantizePrimitive_i", flatten_axis=-2
+            len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2
        )
        x_axes = scale_rules.input_spec
+        dz_axes = (*x_axes[:-2], x_axes[-1])
        out = x_axes

+        colwise_out = (prefix + "out_colwise",)
        if is_2x:
            if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
            else:
-                colwise_out = tuple(x_axes)
-        else:
-            colwise_out = ("j",)
+                colwise_out = out

-        dbias = x_axes[-2:] if is_dbias else ("k",)
-        amax = ("…4",)
+        dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
+        amax = (prefix + "amax",)

        return SdyShardingRule(
-            (("…0",), tuple(x_axes), ("…2",)),
+            (dz_axes, x_axes, ("…2",)),
            (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias),
            **scale_rules.factor_sizes,
        )

@@ -985,6 +981,7 @@ def act_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
+    noop_scaled_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Activation with optional quantization.

@@ -993,6 +990,7 @@ def act_lu(
            Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
        activation_type: Type of activation function to apply.
        quantizer: Optional quantizer for FP8 quantization of the output.
+        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        If quantizer is None:

@@ -1037,6 +1035,10 @@ def act_lu(
            is_outer=True,
        )
        out = out.reshape(output_shape)
+        if noop_scaled_tensor:
+            return ScaledTensorFactory.create_2x(
+                out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype
+            )
        return out

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:

@@ -1090,6 +1092,7 @@ def quantize_dact_dbias(
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
    is_dbias: bool = True,
    quantizer: Optional[Quantizer] = None,
+    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor, jnp.ndarray]:
    """Compute gradients of activation and bias with optional quantization.

@@ -1100,6 +1103,7 @@ def quantize_dact_dbias(
        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
        is_dbias: If True, compute bias gradient. Defaults to True.
        quantizer: Optional quantizer for FP8 quantization of the output.
+        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:

@@ -1113,13 +1117,49 @@ def quantize_dact_dbias(
            f"{x.shape} and act_len {act_len}"
        )

+    scale = jnp.empty((), jnp.float32)
+    act_type_id = ActivationEnum[activation_type]
    PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive
-    if not PrimitiveClass.enabled():
+    if not PrimitiveClass.enabled() or (
+        quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE
+    ):
        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

-    # TE/common does not support colwise-only quantization yet
-    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
-        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)

+    if quantizer is None:
+        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
+            dz,
+            x,
+            scale,
+            # outputs float32 for dbias accumulation
+            out_dtype=(jnp.float32 if is_dbias else x.dtype),
+            # default value for no scaling, TE/common ignore this value when scale is unset
+            scaling_mode=ScalingMode.NO_SCALING.value,
+            is_2x=False,
+            # unused
+            scale_dtype=jnp.float32,
+            # unused
+            is_dbias=False,
+            act_enum=act_type_id,
+            act_len=act_len,
+            is_outer=True,
+        )
+        output = output.astype(x.dtype)
+        dbias = None
+        if is_dbias:
+            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
+        if noop_scaled_tensor:
+            return (
+                ScaledTensorFactory.create_2x(
+                    output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype,
+                ),
+                dbias,
+            )
+        return output, dbias

    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):

@@ -1145,31 +1185,6 @@ def quantize_dact_dbias(
    if war_output is not None:
        return war_output

-    scale = jnp.empty((), jnp.float32)
-    act_type_id = ActivationEnum[activation_type]
-    if quantizer is None:
-        output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind(
-            dz,
-            x,
-            scale,
-            # outputs float32 for dbias accumulation
-            out_dtype=(jnp.float32 if is_dbias else x.dtype),
-            # default value for no scaling, TE/common ignore this value when scale is unset
-            scaling_mode=ScalingMode.NO_SCALING.value,
-            is_2x=False,
-            # unused
-            scale_dtype=jnp.float32,
-            # unused
-            is_dbias=False,
-            act_enum=act_type_id,
-            act_len=act_len,
-            is_outer=True,
-        )
-        dbias = None
-        if is_dbias:
-            dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
-        return output.astype(x.dtype), dbias

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Current scaling does not support fused operations. Perform dact in higher precision then quantize after.
        out = dact_lu(

@@ -1183,7 +1198,7 @@ def quantize_dact_dbias(
        )
        return out, dbias

-    if isinstance(quantizer, DelayedScaleQuantizer):
+    if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    # TE/common dact_dbias_quantize does not support gated act yet

@@ -1243,6 +1258,7 @@ def dact_lu(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
+    noop_scale_tensor: bool = False,
) -> Union[jnp.ndarray, ScaledTensor]:
    """
    Backward pass for activation with optional quantization.

@@ -1252,6 +1268,7 @@ def dact_lu(
        x: Input tensor that was used in forward pass.
        activation_type: Type of activation function that was applied.
        quantizer: Optional quantizer for FP8 quantization of the output gradient.
+        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        The gradient of the activation with respect to the input.

@@ -1262,5 +1279,6 @@ def dact_lu(
        activation_type=activation_type,
        is_dbias=False,
        quantizer=quantizer,
+        noop_scaled_tensor=noop_scale_tensor,
    )
    return output
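Editor's note: to make the act_lu docstring's shape convention concrete, a gated activation expects the input laid out as (..., ACT_DIM=2, K). Below is a plain jax.numpy sketch of a SiLU-gated activation; it is illustrative only, bypasses the TE primitives entirely, and which slice acts as the gate is an assumption:

import jax
import jax.numpy as jnp

x = jnp.ones((8, 2, 128))                 # (..., ACT_DIM=2, K) for a gated activation
gate, linear = x[..., 0, :], x[..., 1, :]
out = jax.nn.silu(gate) * linear          # unquantized reference output, shape (8, 128)
print(out.shape)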
transformer_engine/jax/cpp_extensions/gemm.py

@@ -3,19 +3,26 @@
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import math
import operator
from collections.abc import Iterable
from typing import Tuple, Sequence, Union
from functools import partial, reduce

import jax
import jax.numpy as jnp
from transformer_engine_jax import get_device_compute_capability, get_num_compute_streams
from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule

import transformer_engine_jax as tex
from transformer_engine_jax import get_num_compute_streams

from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize
from ..quantize import (
    ScaledTensor,
    ScaledTensor2x,
    GroupedScaledTensor1x,
    ScalingMode,
    Quantizer,
@@ -24,10 +31,20 @@ from ..quantize import (
    QuantizerSet,
    QuantizeLayout,
    noop_quantizer_set,
    is_fp8_gemm_with_all_layouts_supported,
    apply_padding_to_scale_inv,
)
from .misc import get_padded_spec

-__all__ = ["gemm", "grouped_gemm", "is_gemm_with_all_layouts_supported"]
+__all__ = [
+    "gemm",
+    "grouped_gemm",
+    "gemm_uses_jax_dot",
+    "sanitize_dims",
+    "get_non_contracting_dims",
+    "transpose_dims",
+]

num_cublas_streams = get_num_compute_streams()
@@ -35,14 +52,924 @@ num_cublas_streams = get_num_compute_streams()

def get_cublas_workspace_size_bytes() -> None:
    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
-    if get_device_compute_capability(0) >= 90:
+    if tex.get_device_compute_capability(0) >= 90:
        return 33_554_432
    return 4_194_304


def is_gemm_with_all_layouts_supported() -> False:
    """Return True if using blackwell, False otherwise."""
    return get_device_compute_capability(0) >= 100


def sanitize_dims(ndim: int, dims: Union[int, Sequence[int]]) -> Sequence[int]:
    """Convert relative (negative) indexes to absolute dimension numbers."""
    dims_ = dims if isinstance(dims, Iterable) else (dims,)
    if len(dims_) == 0:
        return dims_
    return tuple(ndim + dim if dim < 0 else dim for dim in dims_ if dim is not None)


def get_non_contracting_dims(ndim, contracting_dims):
    """Return a tuple of dimensions not included in the contracting dimensions."""
    contracting_dims = sanitize_dims(ndim, contracting_dims)
    return tuple(dim for dim in range(ndim) if dim not in contracting_dims)


def transpose_dims(ndim, dims_to_transpose, flatten_axis=-1):
    """Compute the new dimension numbers after transpose."""
    if len(dims_to_transpose) == 0:
        return dims_to_transpose
    flatten_axis = ndim - flatten_axis if flatten_axis > 0 else flatten_axis
    transposed_dims = (*range(flatten_axis, ndim), *range(flatten_axis))
    return tuple(transposed_dims.index(dim) for dim in dims_to_transpose)
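Editor's note: a quick usage sketch of the newly exported dimension helpers, assuming the module import path matches the file location; the first two printed values are worked by hand, the last is left for the reader to inspect:

from transformer_engine.jax.cpp_extensions.gemm import (
    sanitize_dims, get_non_contracting_dims, transpose_dims)

ndim = 3
print(sanitize_dims(ndim, (-1,)))               # (2,): negative index made absolute
print(get_non_contracting_dims(ndim, (-1,)))    # (0, 1)
print(transpose_dims(ndim, (2,), flatten_axis=1))  # position of dim 2 after the transpose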
def _compatible_fp8_gemm_dtypes(lhs_dtype, rhs_dtype) -> bool:
    lhs, rhs, e4m3, e5m2 = map(
        dtypes.canonicalize_dtype,
        (
            lhs_dtype,
            rhs_dtype,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ),
    )

    # FP8 GEMM supports (e4m3 x e4m3), (e4m3 x e5m2) and (e5m2 x e4m3)
    if (lhs is e4m3 and rhs in (e4m3, e5m2)) or (lhs in (e4m3, e5m2) and rhs is e4m3):
        return True

    # Any other combination of data types is not supported
    return False


def _get_gemm_layout(
    operand_ndims: Tuple[int, int], contracting_dims: Tuple[Sequence[int], Sequence[int]]
) -> Tuple[bool, bool]:
    lhs_contracting, rhs_contracting = map(sanitize_dims, operand_ndims, contracting_dims)
    lhs_is_transposed = operand_ndims[0] - 1 not in lhs_contracting
    rhs_is_transposed = operand_ndims[1] - 1 in rhs_contracting
    return lhs_is_transposed, rhs_is_transposed


def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims):
    lhs_q = lhs
    rhs_q = rhs

    if not isinstance(lhs, ScaledTensor) and lhs_quantizer is not None:
        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims[0])
        lhs_is_transposed = lhs.ndim - 1 not in lhs_cdims
        need_lhs_colwise = lhs_is_transposed and (
            lhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
        )
        flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims)
        lhs_q = lhs_quantizer.quantize(
            lhs,
            is_rowwise=not need_lhs_colwise,
            is_colwise=need_lhs_colwise,
            flatten_axis=flatten_axis,
        )

    if not isinstance(rhs, ScaledTensor) and rhs_quantizer is not None:
        rhs_cdims = sanitize_dims(rhs.ndim, contracting_dims[1])
        rhs_is_transposed = rhs.ndim - 1 in rhs_cdims
        need_rhs_colwise = not rhs_is_transposed and (
            rhs_quantizer.scaling_mode.is_1d_block_scaling()
            or not is_fp8_gemm_with_all_layouts_supported()
        )
        flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1
        rhs_q = rhs_quantizer.quantize(
            rhs,
            is_rowwise=not need_rhs_colwise,
            is_colwise=need_rhs_colwise,
            flatten_axis=flatten_axis,
        )

    assert not isinstance(lhs_q, ScaledTensor2x)
    assert not isinstance(rhs_q, ScaledTensor2x)

    return lhs_q, rhs_q
class GemmPrimitive(BasePrimitive):
    """
    Primitive for cuBLAS GEMM
    """

    name = "te_gemm_ffi"
    multiple_results = True
    impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype,
                 contracting_dims, batched_dims, lhs_quantized_colwise, rhs_quantized_colwise,
                 scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator):
        del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator

        def _dims_are_consecutive(dims):
            if len(dims) <= 1:
                return True
            return sorted(dims) == list(range(min(dims), max(dims) + 1))

        # Sanity-check operand layouts and types
        operand_ndims = (lhs.ndim, rhs.ndim)

        lhs_contracting_dims, rhs_contracting_dims = map(
            sanitize_dims, operand_ndims, contracting_dims
        )
        assert _dims_are_consecutive(lhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for LHS operand, but got "
            f"{lhs_contracting_dims}."
        )
        assert _dims_are_consecutive(rhs_contracting_dims), (
            "cuBLAS GEMM expected consecutive contracting dimensions for RHS operand, but got "
            f"{rhs_contracting_dims}."
        )

        lhs_batch_dims, rhs_batch_dims = map(sanitize_dims, operand_ndims, batched_dims)
        assert _dims_are_consecutive(lhs_batch_dims), (
            "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got "
            f"{lhs_batch_dims}."
        )
        assert _dims_are_consecutive(rhs_batch_dims), (
            "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got "
            f"{rhs_batch_dims}."
        )
        if len(lhs_batch_dims) == 0:
            assert (
                len(rhs_batch_dims) == 0
            ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched."
        elif len(rhs_batch_dims) != 0:
            assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all(
                bdim in rhs_contracting_dims for bdim in rhs_batch_dims
            ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched."

        lhs_contracting_size, rhs_contracting_size = map(
            lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]),
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        assert lhs_contracting_size == rhs_contracting_size, (
            "cuBLAS GEMM operands have incompatible contracting dimensions: "
            f"{lhs.shape} @ idx {lhs_contracting_dims} X {rhs.shape} @ idx {rhs_contracting_dims}."
        )

        lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims)
        if scaling_mode != ScalingMode.NO_SCALING:
            assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), (
                "cuBLAS GEMM quantized operands have incompatible data types: "
                f"{lhs.dtype} x {rhs.dtype}."
            )
            assert (
                lhs_scale_inv.size > 0 and rhs_scale_inv.size > 0
            ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
            if (
                scaling_mode != ScalingMode.MXFP8_1D_SCALING
                and not tex.is_non_nt_fp8_gemm_supported()
            ):
                assert not lhs_is_transposed and rhs_is_transposed, (
                    "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
                    "require non-transposed LHS and transposed RHS operands "
                    "(`contracting_dims=((-1, ), (-1, ))`)."
                )

        # Determine output shape and dtype
        assert (
            dtypes.canonicalize_dtype(out_dtype).itemsize > 1
        ), "cuBLAS GEMM custom op does not support 8-bit quantized output types."
        lhs_non_contracting_shape, rhs_non_contracting_shape = map(
            lambda shape, dims: [shape[dim] for dim in range(len(shape)) if dim not in dims],
            (lhs.shape, rhs.shape),
            (lhs_contracting_dims, rhs_contracting_dims),
        )
        out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
        output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)

        # Validate bias
        bias_shape = (0,)
        bias_dtype = out_dtype
        if fuse_bias:
            expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape)
            if not grad:
                assert bias.size == expected_bias_size, (
                    "cuBLAS GEMM bias tensor has incorrect shape, "
                    f"expected ({expected_bias_size}, ) but found {bias.shape}."
                )
                assert bias.dtype == out_dtype, (
                    "cuBLAS GEMM bias tensor has incorrect data type, "
                    f"expected {bias_dtype} but found {bias.dtype}."
                )
                bias_shape = bias.shape
            else:
                bias_shape = rhs_non_contracting_shape
        bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype)

        # Validate pre-GeLU
        pre_gelu_shape = (0,)
        pre_gelu_dtype = out_dtype
        if fuse_gelu:
            pre_gelu_shape = out_shape
            if grad:
                pre_gelu_ndim = len(pre_gelu_shape)
                assert gelu_input.ndim == pre_gelu_shape and all(
                    gelu_input.shape[i] == pre_gelu_shape[i] for i in range(pre_gelu_ndim)
                ), (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect shape, "
                    f"expected {pre_gelu_shape} but found {gelu_input.shape}."
                )
                assert gelu_input.dtype == out_dtype, (
                    "cuBLAS GEMM pre-GeLU tensor has incorrect data type, "
                    f"expected {pre_gelu_dtype} but found {gelu_input.dtype}."
                )
        pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)

        # Need extra workspace for swizzled scale factors
        lhs_swizzle_size = 0
        rhs_swizzle_size = 0
        swizzle_dtype = jnp.uint8
        if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
            lhs_swizzle_size = lhs_scale_inv.size
            rhs_swizzle_size = rhs_scale_inv.size
        lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype)
        rhs_swizzle = jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype)

        # Declare cuBLAS workspace
        # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
        # necessarily 256 bytes aligned, we add some padding to ensure alignment.
        workspace_size = get_cublas_workspace_size_bytes() + 256
        workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)

        return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace

    @staticmethod
    def outer_abstract(*args, **kwargs):
        outputs = GemmPrimitive.abstract(*args, **kwargs)
        return outputs[:-3]  # discard workspace arrays

    @staticmethod
    def lowering(ctx, lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype,
                 contracting_dims, batched_dims, lhs_quantized_colwise, rhs_quantized_colwise,
                 scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator):
        del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype
        lhs_aval, _, rhs_aval, *_ = ctx.avals_in
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
        lhs_transposed, rhs_transposed = _get_gemm_layout(
            (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims)
        )

        args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input)
        kwargs = {
            "scaling_mode": int(scaling_mode.value),
            "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
            "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
            "lhs_transposed": lhs_transposed,
            "rhs_transposed": rhs_transposed,
            "fuse_bias": fuse_bias,
            "fuse_gelu": fuse_gelu,
            "grad": grad,
            "use_split_accumulator": use_split_accumulator,
        }

        operand_output_aliases = {}
        if fuse_bias and not grad:
            operand_output_aliases.update({4: 1})  # bias <-> bias_grad
        if fuse_gelu and grad:
            operand_output_aliases.update({5: 2})  # gelu_input <-> pre_gelu_out

        return jax.ffi.ffi_lowering(
            GemmPrimitive.name,
            operand_output_aliases=operand_output_aliases,
        )(ctx, *args, **kwargs)

    @staticmethod
    def impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, out_dtype,
             contracting_dims, batched_dims, lhs_quantized_colwise, rhs_quantized_colwise,
             scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator):
        lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
        lhs_transposed, rhs_transposed = _get_gemm_layout(
            (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims)
        )
        lhs_scale_inv = apply_padding_to_scale_inv(
            lhs_scale_inv,
            scaling_mode,
            lhs.shape,
            is_colwise=lhs_quantized_colwise,
            flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims),
        )
        rhs_scale_inv = apply_padding_to_scale_inv(
            rhs_scale_inv,
            scaling_mode,
            rhs.shape,
            is_colwise=rhs_quantized_colwise,
            flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1,
        )

        outputs = GemmPrimitive.inner_primitive.bind(
            lhs,
            lhs_scale_inv,
            rhs,
            rhs_scale_inv,
            bias,
            gelu_input,
            out_dtype=out_dtype,
            contracting_dims=contracting_dims,
            batched_dims=batched_dims,
            lhs_quantized_colwise=lhs_quantized_colwise,
            rhs_quantized_colwise=rhs_quantized_colwise,
            scaling_mode=scaling_mode,
            fuse_bias=fuse_bias,
            fuse_gelu=fuse_gelu,
            grad=grad,
            use_split_accumulator=use_split_accumulator,
        )
        return outputs[:-3]  # discard workspace arrays

    @staticmethod
    def batcher(batched_args, jax_batch_dims, out_dtype, contracting_dims, batched_dims,
                lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu,
                grad, use_split_accumulator):
        assert GemmPrimitive.outer_primitive is not None
        lhs, _, rhs, *_ = batched_args
        lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims
        arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)
        arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims
        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), (
            "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch "
            f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
        )
        arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims
        assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), (
            "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch "
            f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}."
        )

        # Output is batched like the non-contracting batch dimensions of the LHS operand
        lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims)
        lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims)
        out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims

        # Bias gradient is never batched
        bias_bdims = (None,)

        # Pre-GeLU output, if exists, is batched like GEMM output
        pre_gelu_bdims = (None,)
        if fuse_gelu and not grad:
            pre_gelu_bdims = out_bdims

        return (
            GemmPrimitive.outer_primitive.bind(
                *batched_args,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                batched_dims=batched_dims,
                lhs_quantized_colwise=lhs_quantized_colwise,
                rhs_quantized_colwise=rhs_quantized_colwise,
                scaling_mode=scaling_mode,
                fuse_bias=fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
            ),
            (out_bdims, bias_bdims, pre_gelu_bdims),
        )

    @staticmethod
    def _decompose_operand_specs(specs, contracting_dims, batch_dims):
        ndims = len(specs)
        cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims))

        # Batch specs
        bspecs = tuple(specs[i] for i in bdims)

        # Non-batch leading dimension specs
        lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims)

        # Non-batch contracting dimension specs
        cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims)

        return bspecs, lspecs, cspecs

    @staticmethod
    def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims):
        lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
        lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
        lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map(
            sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims
        )
        (
            (lhs_bspecs, lhs_lspecs, lhs_cspecs),
            (rhs_bspecs, rhs_lspecs, rhs_cspecs),
        ) = map(
            GemmPrimitive._decompose_operand_specs,
            (lhs_specs, rhs_specs),
            (lhs_cdims, rhs_cdims),
            (lhs_bdims, rhs_bdims),
        )

        # Batched dimensions must have the same sharding
        if len(lhs_bdims) > 0 and len(rhs_bdims) > 0:
            assert all(
                lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs)
            ), (
                "cuBLAS GEMM operand batch dimensions must have the same sharding: "
                f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}."
            )

        # Only one each of the non-batched leading dimensions and non-batched contracting
        # dimensions can be sharded
        lhs_ldims, rhs_ldims = map(
            lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude),
            (lhs_ndim, rhs_ndim),
            (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims),
        )
        (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map(
            lambda specs: tuple(spec for spec in specs if spec is not None),
            (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs),
        )
        assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, (
            "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: "
            f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}."
        )
        assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, (
            "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: "
            f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}."
        )

        # Extract single leading and contracting dimension specs
        (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map(
            lambda specs: None if len(specs) == 0 else specs[0],
            (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none),
        )

        # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts
        # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands.
        # 1. K1 == K2 != None and N == None
        #    LHS: (B, M, K)
        #    RHS: (B, None, K)
        #    OUT: (B, M, None) --(AR)-> (B, M, None)
        # 2. K1 == K2 != None and M == N != None
        #    LHS: (B, M, K)
        #    RHS: (B, N, K)--(AG)->(B, None, K)
        #    OUT: (B, M, None) --(RS)--> (B, M, N)
        # 3. M == N
        #    LHS: (B, M, K)--(AG)->(B, M, None)
        #    RHS: (B, M, K)--(AG)->(B, None, None)
        #    OUT: (B, M, None)
        # 4. M != N
        #    LHS: (B, M, K)--(AG)->(B, M, None)
        #    RHS: (B, N, K)--(AG)->(B, N, None)
        #    OUT: (B, M, N)
        reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec
        all_reduce_output = reduce_flag and rhs_lspec is None
        reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec
        all_reduce_spec = reduce_scatter_spec = scatter_dim = None

        lhs_non_contracting_specs, rhs_non_contracting_specs = map(
            lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims),
            (lhs_specs, rhs_specs),
            (lhs_cdims, rhs_cdims),
        )
        out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs)

        if reduce_scatter_output:
            # All-gather (if necessary) the non-batch non-contracting dimension of RHS
            # (B, N, K) --(AG)-> (B, None, K)
            # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N)
            rhs_spec = tuple(
                rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim)
            )
            reduce_scatter_spec = lhs_cspec
            scatter_dim = out_specs.index(rhs_lspec)
        elif all_reduce_output:
            # Set all output trailing dimensions to zero
            out_specs = (
                *lhs_non_contracting_specs,
                *[None for _ in range(len(rhs_non_contracting_specs))],
            )
            all_reduce_spec = lhs_cspec
        else:
            # All-gather (if necessary) the non-batch contracting dimensions
            # (B, M, K) --(AG)-> (B, M, None)
            # (B, N, K) --(AG)-> (B, N, None)
            # (B, M, None) x (B, N, None)^T = (B, M, N)
            lhs_specs = tuple(
                None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i]
                for i in range(lhs_ndim)
            )
            rhs_specs = tuple(
                None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i]
                for i in range(rhs_ndim)
            )
            # Check if RHS non-contracting spec also appears in the LHS non-contracting specs
            if rhs_lspec is not None and rhs_lspec in tuple(
                lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims
            ):
                # All-gather (if necessary) the non-batch non-contracting dimensions of RHS
                # (B, N, None) --(AG)-> (B, None, None)
                # (B, M, None) x (B, None, None)^T = (B, M, None)
                rhs_specs = tuple(
                    None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i]
                    for i in range(rhs_ndim)
                )
                # Set all output trailing dimensions to zero
                out_specs = (
                    *lhs_non_contracting_specs,
                    *[None for _ in range(len(rhs_non_contracting_specs))],
                )

        # Bias and Pre-GeLU sharding is based on GEMM output
        bias_specs = out_specs[len(lhs_non_contracting_specs) :]
        gelu_specs = out_specs

        return (
            (lhs_specs, rhs_specs, bias_specs, gelu_specs),
            (out_specs, bias_specs, gelu_specs),
            all_reduce_spec,
            reduce_scatter_spec,
            scatter_dim,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, contracting_dims, batched_dims,
                                     lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode,
                                     fuse_bias, fuse_gelu, grad, use_split_accumulator, mesh,
                                     arg_infos, result_infos):
        del (out_dtype, lhs_quantized_colwise, rhs_quantized_colwise, scaling_mode, grad)
        del use_split_accumulator, result_infos

        (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
            GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)
        )
        out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))

        # Discard bias gradient spec if there is no bias fusion
        if not fuse_bias:
            dbias_specs = (None,)
        dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))

        # Discard pre-GeLU output spec if there is no GeLU fusion
        if not fuse_gelu:
            pre_gelu_specs = (None,)
        pre_gelu_sharding = NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))

        return [out_sharding, dbias_sharding, pre_gelu_sharding]

    @staticmethod
    def partition(out_dtype, contracting_dims, batched_dims, lhs_quantized_colwise,
                  rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad,
                  use_split_accumulator, mesh, arg_infos, result_infos):
        del result_infos

        (
            (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
            (out_specs, dbias_specs, pre_gelu_specs),
            all_reduce_spec,
            reduce_scatter_spec,
            scatter_dim,
        ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims)

        # Assemble argument shardings
        # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
        none_sharding = NamedSharding(mesh, PartitionSpec(None))
        lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
        rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
        arg_shardings = (
            lhs_sharding,
            lhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
            rhs_sharding,
            rhs_sharding if scaling_mode.is_1d_block_scaling() else none_sharding,
        )

        # Discard bias input spec if there is no bias fusion
        if not fuse_bias:
            bias_input_specs = (None,)
        arg_shardings += (NamedSharding(mesh, PartitionSpec(*bias_input_specs)),)

        # Discard pre-GeLU input spec if there is no GeLU fusion
        if not fuse_gelu:
            gelu_input_specs = (None,)
        arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),)

        # Assemble output shardings
        out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]

        # Discard bias gradient spec if there is no bias fusion
        if not fuse_bias:
            dbias_specs = (None,)
        out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))

        # Discard pre-GeLU output spec if there is no GeLU fusion
        if not fuse_gelu:
            pre_gelu_specs = (None,)
        out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))

        def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
            outputs = GemmPrimitive.impl(
                lhs,
                lhs_scale_inv,
                rhs,
                rhs_scale_inv,
                bias,
                gelu_input,
                out_dtype=out_dtype,
                contracting_dims=contracting_dims,
                batched_dims=batched_dims,
                lhs_quantized_colwise=lhs_quantized_colwise,
                rhs_quantized_colwise=rhs_quantized_colwise,
                scaling_mode=scaling_mode,
                fuse_bias=fuse_bias,
                fuse_gelu=fuse_gelu,
                grad=grad,
                use_split_accumulator=use_split_accumulator,
            )

            # All-Reduce/Reduce-Scatter GEMM output
            if all_reduce_spec is not None:
                outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec)
                if fuse_gelu and not grad:
                    outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec)
            elif reduce_scatter_spec is not None:
                outputs[0] = jax.lax.psum_scatter(
                    outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
                )
                if fuse_gelu and not grad:
                    outputs[2] = jax.lax.psum_scatter(
                        outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True
                    )

            return outputs

        return mesh, _sharded_impl, out_shardings, arg_shardings

    @staticmethod
    def shardy_sharding_rule(out_dtype, contracting_dims, batched_dims, lhs_quantized_colwise,
                             rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad,
                             use_split_accumulator, mesh, operand_types, result_types):
        del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator
        del mesh, result_types

        prefix = "GemmPrimitive_"

        def _generate_operand_rules(name, ndim, cdims, bdims):
            specs = []
            ldims = tuple(i for i in range(ndim) if i not in bdims + cdims)
            for i in range(ndim):
                dim_name = None
                if i in bdims:
                    dim_idx = bdims.index(i) if len(bdims) > 1 else ""
                    dim_name = f"b{dim_idx}"
                elif i in cdims:
                    dim_idx = cdims.index(i) if len(cdims) > 1 else ""
                    dim_name = f"k{dim_idx}"
                else:
                    dim_idx = ldims.index(i) if len(ldims) > 1 else ""
                    dim_name = f"{name}_l{dim_idx}"
                specs.append(prefix + dim_name)
            return specs

        lhs, _, rhs, *_ = operand_types
        operand_ndims = (len(lhs.shape), len(rhs.shape))
        (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map(
            lambda dims: map(sanitize_dims, operand_ndims, dims),
            (contracting_dims, batched_dims),
        )
        lhs_specs, rhs_specs = map(
            _generate_operand_rules,
            ("lhs", "rhs"),
            operand_ndims,
            (lhs_cdims, rhs_cdims),
            (lhs_bdims, rhs_bdims),
        )
        lhs_scale_specs = ("…1",)
        rhs_scale_specs = ("…2",)
        if scaling_mode.is_1d_block_scaling():
            # Shardy rules for MXFP8 scales cannot be related to the operands because of the
            # global-unpadding and local-padding workflow. This can potentially insert expensive
            # re-shards in the partition call later if the scales are not already sharded correctly.
            lhs_scale_specs, rhs_scale_specs = map(
                lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
                (lhs_specs, rhs_specs),
            )

        lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
        rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
        out_spec = (*lhs_non_cspec, *rhs_non_cspec)
        bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
        gelu_spec = out_spec if fuse_gelu else ("…5",)

        return SdyShardingRule(
            operand_mappings=(
                lhs_specs,
                lhs_scale_specs,
                rhs_specs,
                rhs_scale_specs,
                bias_spec,
                gelu_spec,
            ),
            result_mappings=(
                out_spec,
                bias_spec,
                gelu_spec,
            ),
        )


register_primitive(GemmPrimitive)
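Editor's note: the case analysis inside _parse_operand_output_specs mirrors how XLA itself handles a sharded contraction. Below is a conceptual sketch of case 1 (both operands shard the contracting dimension, RHS leading dimension unsharded) using plain jax.numpy rather than the TE primitive; the mesh axis name "k" is arbitrary:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("k",))
lhs = jax.device_put(jnp.ones((64, 128)), NamedSharding(mesh, PartitionSpec(None, "k")))
rhs = jax.device_put(jnp.ones((256, 128)), NamedSharding(mesh, PartitionSpec(None, "k")))

@jax.jit
def matmul(a, b):
    # each device computes a partial (M, N) product over its K shard;
    # the compiler inserts the all-reduce over "k"
    return a @ b.T

out = matmul(lhs, rhs)    # (64, 256)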
def
gemm_uses_jax_dot
()
->
bool
:
"""Check if the GEMM call directs to the TE custom cuBLAS call or native JAX dot."""
return
not
GemmPrimitive
.
enabled
()
def _te_gemm(
    lhs: Union[jax.Array, ScaledTensor],
    rhs: Union[jax.Array, ScaledTensor],
    bias: jax.Array = None,
    gelu_input: jax.Array = None,
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
    fuse_bias: bool = False,
    fuse_gelu: bool = False,
    grad: bool = False,
    use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
) -> Tuple[jax.Array, ...]:

    # Prepare non-quantized GEMM operands
    lhs_data = lhs
    rhs_data = rhs
    lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
    scaling_mode = ScalingMode.NO_SCALING
    lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
    lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
    lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims)

    # Quantize operands (if necessary)
    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    # Extract GEMM custom op inputs from quantized operands
    if isinstance(lhs_q, ScaledTensor):
        assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, (
            "cuBLAS GEMM with quantized LHS and non-quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the RHS operand."
        )
        if isinstance(lhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            lhs_q = lhs_q.get_colwise_tensor() if lhs_is_transposed else lhs_q.get_rowwise_tensor()
        scaling_mode = lhs_q.scaling_mode
        lhs_data = lhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        if lhs_q.data_layout == "T":
            lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis)
            lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis)

    if isinstance(rhs_q, ScaledTensor):
        assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, (
            "cuBLAS GEMM with non-quantized LHS and quantized RHS operands requires a valid "
            "`Quantizer` object to quantize the LHS operand."
        )
        if isinstance(rhs_q, ScaledTensor2x):
            # Choose the quantization of the contracting dimension(s)
            rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor()
        assert rhs_q.scaling_mode == lhs_q.scaling_mode, (
            "cuBLAS GEMM quantized operands have mismatched scaling types, "
            f"LHS: {lhs_q.scaling_mode} x RHS: {rhs_q.scaling_mode}."
        )
        rhs_data = rhs_q.data
        rhs_scale_inv = rhs_q.scale_inv
        if rhs_q.data_layout == "T":
            rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis)
            rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis)

    # Dummy empties for bias and gelu
    out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype
    if bias is None or not (fuse_bias and not grad):
        bias = jnp.empty(0, dtype=out_dtype)
    if gelu_input is None or not (fuse_gelu and grad):
        gelu_input = jnp.empty(0, dtype=out_dtype)

    return GemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
        rhs_data,
        rhs_scale_inv,
        bias,
        gelu_input,
        out_dtype=out_dtype,
        contracting_dims=(lhs_cdims, rhs_cdims),
        batched_dims=(lhs_bdims, rhs_bdims),
        lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False,
        rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False,
        scaling_mode=scaling_mode,
        fuse_bias=fuse_bias,
        fuse_gelu=fuse_gelu,
        grad=grad,
        use_split_accumulator=use_split_accumulator,
    )
class GroupedGemmPrimitive(BasePrimitive):
...
@@ -102,15 +1029,28 @@ class GroupedGemmPrimitive(BasePrimitive):
            A jnp.ndarray containing the result of the grouped GEMM operation
        """
        del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval
        del K, lhs_is_trans, rhs_is_trans, scaling_mode, has_bias
        del lhs_scale_inv_aval, rhs_scale_inv_aval
        del K, lhs_is_trans, rhs_is_trans, has_bias
        # TODO(Phuong): move some shape checks from Cpp to here
        workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
        # JAX buffer pointers are 128-aligned
        # 255 is added to the workspace size to ensure workspace ptr is 256-aligned
        workspace_size += 255
        workspace_alignment_padding = 256
        tensor_scaling_sinv_aligment = 16
        mxfp8_scaling_sinv_alignment_padding = 256
        # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
        # necessarily 256 bytes aligned, we add some padding to ensure alignment.
        workspace_size += workspace_alignment_padding
        if scaling_mode in (
            ScalingMode.DELAYED_TENSOR_SCALING.value,
            ScalingMode.CURRENT_TENSOR_SCALING.value,
        ):
            # For tensor scaling, each matrix has a single scale value, but it
            # needs to be aligned to 16 bytes for CUDA 12.9.1 and later.
            workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
            workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment
        elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value:
            # We also pad scale_inv swizzle buffers size for 256 bytes alignment.
            workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
            workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding
        workspace_aval = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
        # TODO(phuong): We should make separate tmp buffers for swizzled scales to avoid unaligned-by-256 workspace ptr issue

        out_shape = (M, N)
        if is_grouped_dense_wgrad:
...
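Editorial note: the padding constants above exist because the cuBLAS workspace pointer must sit on a 256-byte boundary while XLA/JAX buffers only guarantee a weaker alignment. A minimal Python sketch (illustrative only; the sizes and the address below are hypothetical, not values from this commit) of why padding the allocation by the alignment amount and rounding the start address up always leaves an aligned region of the requested size:

def round_up_to_256(addr: int) -> int:
    """Round an address up to the next multiple of 256 (no-op if already aligned)."""
    return (addr + 255) & ~255

requested = 4096                    # bytes the kernel actually needs
padded = requested + 256            # extra pad so the aligned start still leaves enough room
raw_start = 0x7F0000000123          # hypothetical, possibly unaligned buffer address
aligned_start = round_up_to_256(raw_start)

assert aligned_start % 256 == 0
assert aligned_start - raw_start < 256                       # at most 255 bytes are skipped
assert (raw_start + padded) - aligned_start >= requested     # aligned region still fits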
@@ -221,11 +1161,8 @@ def _shape_normalization(x, dimension_numbers, already_transposed: bool = False)

def _calculate_remaining_shape(shape, contracting_dims):
    return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)


def _transpose_contract_dims(ndim, contracting_dims):
    return tuple(ndim - i - 1 for i in contracting_dims)[::-1]
    contracting_dims_ = sanitize_dims(len(shape), contracting_dims)
    return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims_)


# Apply jit to guarantee correctness of FP8 GEMM.
...
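Editorial note: the change above normalizes contracting dimensions before filtering them out of the shape, so that negative axis indices (e.g. -1) are handled. An illustrative sketch of the same idea (the `_sanitize` helper below is an assumption standing in for TE's sanitize_dims, not the library function):

def _sanitize(ndim, dims):
    return tuple(d % ndim for d in dims)

def remaining_shape(shape, contracting_dims):
    cdims = _sanitize(len(shape), contracting_dims)
    return tuple(shape[d] for d in range(len(shape)) if d not in cdims)

# (B, M, K) contracted over the last axis, written as -1:
assert remaining_shape((4, 128, 256), (-1,)) == (4, 128)
# Without normalization, -1 would not match axis index 2 and K would leak into the output shape.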
@@ -233,9 +1170,11 @@ def _transpose_contract_dims(ndim, contracting_dims):
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
    if lhs.data_layout == "T":
        lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract)
        lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis)
        lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis)
    if rhs.data_layout == "T":
        rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract)
        rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis)
        rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis)

    dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
...
@@ -280,10 +1219,6 @@ def _jax_gemm_mxfp8_1d(
    lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
    rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))

    # Slice out the padding as scaled_matmul does not support padded scales yet
    lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)])
    rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)])

    # JAX scaled_matmul only supports NT now (TN-gemm)
    # * Expected shape:
    # *   lhs_data (B, M, K)  *  rhs_data (B, N, K)
...
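Editorial note: the "/ 32" slicing being removed above reflects the MXFP8 1D block-scaling layout, where each 32-element block along the contracting dimension K shares a single scale, so an unpadded scale tensor for data of shape (B, M, K) has shape (B, M, K // 32). A small illustrative sketch (shapes are hypothetical; E8M0 scales are represented here as raw bytes rather than a dedicated dtype):

import jax.numpy as jnp

MXFP8_BLOCK = 32
B, M, K = 2, 64, 256
data = jnp.zeros((B, M, K), dtype=jnp.float8_e4m3fn)
scales = jnp.ones((B, M, K // MXFP8_BLOCK), dtype=jnp.uint8)  # one scale per 32-wide block

assert K % MXFP8_BLOCK == 0
assert scales.shape == (B, M, K // MXFP8_BLOCK)  # (2, 64, 8)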
@@ -306,12 +1241,12 @@ def _jax_gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
) -> jnp.ndarray:
    """
    FP8 GEMM via JAX
    """
    dim_nums = (contracting_dims, ((), ()))

    def _jax_gemm_fp8_impl(lhs, rhs):
...
@@ -331,32 +1266,16 @@ def _jax_gemm(
        raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}")

    if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
        return _jax_gemm_fp8_impl(lhs, rhs)

    if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
        if quantizer_set != noop_quantizer_set:
            assert type(quantizer_set.x) is type(quantizer_set.kernel)
            (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
            lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
            rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
            # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
            lhs_q = quantizer_set.x.quantize(
                lhs,
                is_rowwise=lhs_is_rowwise,
                is_colwise=not lhs_is_rowwise,
            )
            rhs_q = quantizer_set.kernel.quantize(
                rhs,
                is_rowwise=rhs_is_rowwise,
                is_colwise=not rhs_is_rowwise,
            )
            return _jax_gemm_fp8_impl(lhs_q, rhs_q)
    lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims)

    if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor):
        return _jax_gemm_fp8_impl(lhs_q, rhs_q)

    if (
        isinstance(lhs, jnp.ndarray)
        and isinstance(rhs, jnp.ndarray)
        and quantizer_set == noop_quantizer_set
        and lhs_quantizer is None
        and rhs_quantizer is None
    ):
        return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
...
@@ -366,30 +1285,109 @@ def _jax_gemm(
def gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """General matrix multiplication with optional quantization.

    Args:
        lhs: First input matrix.
        rhs: Second input matrix.
        contracting_dims: Tuple of two sequences representing the contracting dimensions.
            The first sequence represents the contracting dimensions of the first matrix,
            and the second sequence represents the contracting dimensions of the second matrix.
        quantizer_set: Set of quantizers for FP8 quantization of the output.
            If None, no quantization is applied and the output has the same dtype as the inputs.

    Returns:
        If quantizer_set is None:
            The matrix multiplication result.
            Shape: (M, N)
            Dtype: Same as input dtype
        If quantizer_set is provided:
            A ScaledTensor containing the quantized matrix multiplication result.
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
    batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()),
    lhs_quantizer: Quantizer = None,
    rhs_quantizer: Quantizer = None,
    **kwargs,
) -> Tuple[jnp.ndarray, ...]:
    r"""General matrix multiplication with optional quantization.

    Parameters
    ----------
    lhs: Union[jax.Array, ScaledTensor]
        Left-hand side operand in the matrix multiplication.
    rhs: Union[jax.Array, ScaledTensor]
        Right-hand side operand in the matrix multiplication.
    lhs_quantizer: Quantizer, default = None
        Object for down-casting the LHS operand for quantized GEMM.
    rhs_quantizer: Quantizer, default = None
        Object for down-casting the RHS operand for quantized GEMM.
    contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, ))
        Tuple of sequences representing the contracting dimensions of the operands.
    batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()),
        Tuple of sequences representing the batched dimensions of the operands. This is *not* used
        to perform a batched matrix multiplication, but it is required to avoid a potentially
        undesirable reduction in any batched contracting dimensions when invoked with sharded
        operands (e.g. when computing weight gradients in a Flax module).
    bias: jax.Array, default = None
        Optional additive bias term, required for forward GEMM with bias fusion. Only supported
        with TE's custom call to cuBLAS GEMM.
    gelu_input: jax.Array, default = None
        Pre-GeLU output from forward GEMM, required for backward/grad GEMM with dGeLU fusion. Only
        supported with TE's custom call to cuBLAS GEMM.
    fuse_bias: bool, default = False
        Enable bias addition in forward GEMM or bias gradient in backward GEMM. Only supported with
        TE's custom call to cuBLAS GEMM.
    fuse_gelu: bool, default = False
        Enable GeLU activation in forward GEMM or GeLU gradient in backward GEMM. Only supported
        with TE's custom call to cuBLAS GEMM.
    grad: bool, default = False
        Flag for switching bias and GeLU fusions from forward to backward mode. Only supported with
        TE's custom call to cuBLAS GEMM.
    use_split_accumulator: bool, default = True
        Enable promoting some intermediate sums to higher precision when accumulating the result in
        the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.

    Returns
    -------
    jax.Array:
        Result of the operation. For TE's custom call to cuBLAS GEMM, this result can include the
        GeLU application when `fuse_gelu=True` and `grad=False`, the GeLU gradient contribution
        when `fuse_gelu=True` and `grad=True`, and the additive bias when `fuse_bias=True` and
        `grad=False`.
    Optional[jax.Array]:
        Bias gradient when `fuse_bias=True` and `grad=True`. Only supported with TE's custom call
        to cuBLAS GEMM.
    Optional[jax.Array]:
        Pre-GeLU GEMM output when `fuse_gelu=True` and `grad=False`. This is required as an input
        to `_te_gemm()` with `fuse_gelu=True` and `grad=True` in the backward pass in order to
        compute the GeLU contribution to the gradient. Only supported with TE's custom call to
        cuBLAS GEMM.
    """
    # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility
    if lhs_quantizer is None or rhs_quantizer is None:
        quantizer_set = kwargs.get("quantizer_set", None)
        if quantizer_set is not None:
            lhs_quantizer = quantizer_set.x
            rhs_quantizer = quantizer_set.kernel

    # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
    fuse_bias = kwargs.get("fuse_bias", False)
    fuse_gelu = kwargs.get("fuse_gelu", False)
    if not GemmPrimitive.enabled():
        assert kwargs.get("bias", None) is None and not fuse_gelu, (
            "TE GEMM was invoked with bias fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        assert kwargs.get("gelu_input", None) is None and not fuse_bias, (
            "TE GEMM was invoked with GeLU fusion options that are not supported by the "
            "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS "
            "GEMM primitive is disabled."
        )
        return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)

    outputs = _te_gemm(
        lhs,
        rhs,
        lhs_quantizer=lhs_quantizer,
        rhs_quantizer=rhs_quantizer,
        contracting_dims=contracting_dims,
        batched_dims=batched_dims,
        **kwargs,
    )
    return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)

    # Discard empty outputs
    grad = kwargs.get("grad", False)
    clean_outputs = outputs[0]  # first output is the final result and is never empty
    if (fuse_bias and grad) or (fuse_gelu and not grad):
        clean_outputs = (outputs[0],)
        if fuse_bias and grad:  # only return bias gradient if it exists
            clean_outputs += (outputs[1],)
        if fuse_gelu and not grad:  # only return pre-GeLU output if it exists
            clean_outputs += (outputs[2],)
    return clean_outputs
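Editorial note: a short illustration of the contracting_dims convention documented above, shown with the jax.lax.dot_general fallback that the wrapper names for the case where the custom cuBLAS primitive is disabled. Shapes are hypothetical and no TE import is required for this sketch:

import jax
import jax.numpy as jnp

lhs = jnp.ones((8, 128, 256), dtype=jnp.bfloat16)   # (..., M, K)
rhs = jnp.ones((256, 512), dtype=jnp.bfloat16)      # (K, N)

contracting_dims = ((-1,), (0,))                     # matches the default above
dim_nums = (
    (tuple(d % lhs.ndim for d in contracting_dims[0]),
     tuple(d % rhs.ndim for d in contracting_dims[1])),
    ((), ()),
)
out = jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)
assert out.shape == (8, 128, 512)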
def grouped_gemm(
...
@@ -490,15 +1488,13 @@ def grouped_gemm(
        assert type(quantizer_set.x) is type(quantizer_set.kernel)
        scaling_mode = quantizer_set.x.scaling_mode
        if (
            # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
            # scaling_mode.is_tensor_scaling()
            # and is_gemm_with_all_layouts_supported()
            scaling_mode.is_1d_block_scaling()
            quantizer_set.x.scaling_mode.is_tensor_scaling()
            and is_fp8_gemm_with_all_layouts_supported()
        ):
            lhs_is_rowwise = rhs_is_rowwise = True
        else:
            lhs_is_rowwise = not lhs_is_trans
            rhs_is_rowwise = lhs_is_trans
            rhs_is_rowwise = rhs_is_trans
        quantizer_set.x.q_layout = (
            QuantizeLayout.ROWWISE if lhs_is_rowwise else QuantizeLayout.COLWISE
        )
...
@@ -513,6 +1509,8 @@ def grouped_gemm(
        rhs_data = rhs_q.data
        lhs_scale_inv = lhs_q.scale_inv
        rhs_scale_inv = rhs_q.scale_inv
        lhs_shape = lhs_q.original_shape
        rhs_shape = rhs_q.original_shape

    assert not (
        lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2
...
@@ -520,24 +1518,35 @@ def grouped_gemm(
    # Only support FP8 GEMM with NT layout on Hopper and other earlier GPUs
    # thus additional transpose is required
    # TODO(Phuong): we force Blackwell to also use NT layout for now, need to fix later
    if scaling_mode.is_tensor_scaling():  # and not is_gemm_with_all_layouts_supported():
        lhs_is_trans = False
        rhs_is_trans = True
    if scaling_mode.is_tensor_scaling() and not is_fp8_gemm_with_all_layouts_supported():
        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
            lhs_layout_is_T = lhs.data_layout == "T"
            rhs_layout_is_T = rhs.data_layout == "T"
        else:
            lhs_layout_is_T = lhs_q.data_layout == "T"
            rhs_layout_is_T = rhs_q.data_layout == "T"
        # we can't apply _shape_normalization on the grouped input
        # thus we need to ensure that lhs is in N and rhs is in T
        assert (
            lhs_is_trans == lhs_layout_is_T
        ), "lhs input must be transposed before calling grouped_gemm"
        assert (
            not rhs_is_trans == rhs_layout_is_T
        ), "rhs input must be transposed before calling grouped_gemm"
        lhs_is_trans = False
        rhs_is_trans = True
        lhs_ndim = len(lhs_shape)
        rhs_ndim = len(rhs_shape)
        if lhs_layout_is_T:
            lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim)
        if rhs_layout_is_T:
            rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)
        lhs_data = _shape_normalization(lhs_data, (lhs_contract_dim, ()), not lhs_layout_is_T)
        rhs_data = _shape_normalization(rhs_data, (rhs_contract_dim, ()), rhs_layout_is_T)
        # For rhs [G, K, N], need to exclude the G dim from contract_dim
        if group_sizes.size == rhs_shape[0]:
            rhs_contract_dim = tuple(
                (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim
            )
        else:
            rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim)

    # Calling GroupedGEMM Custom Call
    K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim)
...
@@ -557,9 +1566,6 @@ def grouped_gemm(
    assert not has_bias or bias.shape == (group_sizes.size, N)
    bias = jnp.empty((), jnp.float32) if bias is None else bias

    # TODO(Phuong): support MXFP8_1D_SCALING
    assert scaling_mode != ScalingMode.MXFP8_1D_SCALING, "MXFP8_1D_SCALING is not yet supported"

    (out,) = GroupedGemmPrimitive.outer_primitive.bind(
        lhs_data,
        lhs_scale_inv,
...
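Editorial note: for readers unfamiliar with the operation being dispatched here, a conceptual, pure-JAX sketch of what a grouped GEMM computes (this is not the TE custom call, and the sizes below are hypothetical): the stacked LHS rows are split along M according to group_sizes and each group is multiplied by its own RHS matrix.

import jax.numpy as jnp
import numpy as np

group_sizes = np.array([3, 5, 2])                    # rows per group, sum = 10
K, N, G = 16, 8, len(group_sizes)
lhs = jnp.ones((int(group_sizes.sum()), K))          # (sum(M_i), K)
rhs = jnp.ones((G, K, N))                            # one (K, N) matrix per group

outs, start = [], 0
for g, m_g in enumerate(group_sizes):
    outs.append(lhs[start:start + int(m_g)] @ rhs[g])
    start += int(m_g)
out = jnp.concatenate(outs, axis=0)
assert out.shape == (int(group_sizes.sum()), N)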
transformer_engine/jax/cpp_extensions/misc.py  View file @ 44740c6c
...
@@ -198,14 +198,19 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to
    calculate dbias separately. This function checks if the workaround should be applied.
    """
    if quantizer is None:
        return False
    arch_l_100 = False
    for local_gpu_id in range(len(jax.local_devices())):
        if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
            arch_l_100 = True
            break
    # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE,
    # but this fails when bias fusion is turned on with arch < 100.
    force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
    return (
        quantizer is not None
        and quantizer.q_layout == QuantizeLayout.ROWWISE
        (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE)
        and arch_l_100
        and is_dbias
    )
...
transformer_engine/jax/cpp_extensions/normalization.py  View file @ 44740c6c
...
@@ -587,16 +587,17 @@ class NormFwdPrimitive(BasePrimitive):
            result_types,
        )
        prefix = "NormFwdPrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1
            len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1
        )
        x_axes = scale_rules.input_spec
        out = x_axes[:-1] + ("k",)
        colwise_out = out if is_2x else ("…4",)
        out = x_axes
        colwise_out = out if is_2x else (prefix + "out_colwise",)
        rsigma = x_axes[:-1]
        mu = ("…5",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
        amax = ("…6",)
        mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma
        amax = (prefix + "amax",)

        return SdyShardingRule(
            (x_axes, ("…1",), ("…2",), ("…3",)),
...
@@ -609,7 +610,6 @@ class NormFwdPrimitive(BasePrimitive):
                mu,
                rsigma,
            ),
            **scale_rules.factor_sizes,
        )
...
@@ -1276,6 +1276,7 @@ def normalization_fwd(
    epsilon: float,
    norm_type: str,
    quantizer: Optional[Quantizer],
    noop_scaled_tensor: bool = False,
):
    """Common wrapper for normalization forward pass.
...
@@ -1292,6 +1293,7 @@ def normalization_fwd(
            - 'layernorm': Layer normalization
            - 'rmsnorm': Root mean square normalization
        quantizer: Optional quantizer for FP8 quantization of the output.
        noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None.

    Returns:
        A tuple containing:
...
@@ -1319,6 +1321,15 @@ def normalization_fwd(
    else:
        raise ValueError(f"{norm_type=} is not supported.")

    if quantizer is None and noop_scaled_tensor:
        return (
            ScaledTensorFactory.create_2x(
                output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype
            ),
            mu,
            rsigma,
        )

    return output, mu, rsigma
...
transformer_engine/jax/cpp_extensions/quantization.py  View file @ 44740c6c
...
@@ -36,7 +36,6 @@ from ..quantize import (
    Quantizer,
    GroupedQuantizer,
    QuantizeLayout,
    DelayedScaleQuantizer,
    ScalingMode,
    compute_scale_from_amax,
)
...
@@ -489,9 +488,10 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
    ):
        del out_dtype, scale_dtype, is_outer, mesh, result_types
        prefix = "BaseDBiasQuantizePrimitive_"
        scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
            len(value_types[0].shape),
            unique_var="BaseDBiasQuantizePrimitive_i",
            unique_var=prefix + "x",
            flatten_axis=flatten_axis,
        )
...
@@ -499,22 +499,19 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
        colwise_scale_inv = scale_rules.colwise_rule

        out = x_axes
        colwise_out = (prefix + "out_colwise",)
        if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
            if ScalingMode(scaling_mode).is_tensor_scaling():
                colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
            else:
                colwise_out = x_axes
        else:
            colwise_out = ("j",)
            colwise_scale_inv = ("k",)

        dbias = x_axes[flatten_axis:] if is_dbias else ("l",)
        amax = ("m",)
        dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",)
        amax = (prefix + "amax",)
        return SdyShardingRule(
            (x_axes, ("…1",)),
            (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
            **scale_rules.factor_sizes,
        )
...
@@ -538,11 +535,12 @@ def _jax_quantize(
def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1):
    assert flatten_axis < 0
    sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
    assert sum_axis < dx.ndim, "Flatten axis out of bounds!"
    dtype = dtype or dx.dtype
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim + flatten_axis)),
        axis=tuple(range(sum_axis)),
        keepdims=False,
    )
    return dbias.astype(dtype)
...
@@ -568,6 +566,7 @@ def _quantize_dbias_impl(
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
...
@@ -577,24 +576,34 @@ def _quantize_dbias_impl(
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"

    # Early-exit for non-quantized call
    dq_dtype = dq_dtype or x.dtype
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if not PrimitiveClass.enabled():
        if quantizer is None:
            dbias = None
            if is_dbias:
                return _jax_quantize_dbias(
                    x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis,
                dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
            if noop_scaled_tensor:
                # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor()
                # always works.
                return (
                    ScaledTensorFactory.create_2x(
                        x, None, x, None, ScalingMode.NO_SCALING,
                        dq_dtype=x.dtype, data_layout="NN", flatten_axis=flatten_axis,
                    ),
                    dbias,
                )
            return (
                _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
                None,
            )
            return x, dbias

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE:
    # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE,
    # fall back on the native-JAX quantize implementation
    PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive
    if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
...
@@ -606,9 +615,8 @@ def _quantize_dbias_impl(
            _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis),
            None,
        )
    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    # TE/common custom quantize op does not support dbias fusion with 1x quantization on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_dbias_impl(
            x=x,
...
@@ -620,29 +628,23 @@ def _quantize_dbias_impl(
        dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
        return x, None

    scale = jnp.empty((), jnp.float32)
    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        # Globally reduce amax across all devices for current scaling so we have a single global scale.
        # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this
        # until the tensor is dequantized (e.g. in the GEMM).
        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
        scale = compute_scale_from_amax(amax, quantizer.q_dtype)
    if isinstance(quantizer, DelayedScaleQuantizer):
    elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
        scale = quantizer.scale

    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    # It is faster to use 1x quantization for tensor scaling
    is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100)
    force_1x_quantization = (
        quantizer.scaling_mode.is_tensor_scaling()
        and quantizer.is_2x2x()
        and is_1x_kernel_supported
    )
    q_layout = quantizer.q_layout
    if force_1x_quantization:
        q_layout = QuantizeLayout.ROWWISE
...
@@ -698,6 +700,7 @@ def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.
...
@@ -707,6 +710,8 @@ def quantize(
        quantizer: Quantizer for FP8 quantization of the output.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer
            is None.

    Returns:
        A ScaledTensor containing the quantized input tensor.
...
@@ -715,6 +720,7 @@ def quantize(
        x,
        quantizer=quantizer,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
    return out
...
@@ -724,6 +730,7 @@ def quantize_dbias(
    quantizer: Quantizer,
    is_dbias: bool = True,
    flatten_axis: int = -1,
    noop_scaled_tensor: bool = False,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.
...
@@ -734,6 +741,8 @@ def quantize_dbias(
        is_dbias: If True, compute bias gradient. Defaults to True.
        flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
            Defaults to -1.
        noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when
            quantizer is None.

    Returns:
        A tuple containing:
...
@@ -743,7 +752,11 @@ def quantize_dbias(
            Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_dbias_impl(
        dz, quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        flatten_axis=flatten_axis,
        noop_scaled_tensor=noop_scaled_tensor,
    )
...
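Editorial note: the current-tensor-scaling branch above derives a single global scale from the tensor's amax. An illustrative sketch of the usual recipe, assuming scale = fp8_max / amax; this is not TE's compute_scale_from_amax, which may add clamping or other details, and the input values are hypothetical:

import jax.numpy as jnp

x = jnp.array([[0.5, -3.0], [120.0, 7.5]], dtype=jnp.float32)
amax = jnp.amax(jnp.abs(x)).astype(jnp.float32)           # global amax = 120.0
fp8_max = float(jnp.finfo(jnp.float8_e4m3fn).max)         # 448 for E4M3
scale = fp8_max / amax                                    # multiply by this before casting
x_fp8 = (x * scale).astype(jnp.float8_e4m3fn)
x_dq = x_fp8.astype(jnp.float32) / scale                  # dequantize with the inverse scale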
transformer_engine/jax/csrc/extensions.h  View file @ 44740c6c
...
@@ -119,6 +119,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

// GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
...
transformer_engine/jax/csrc/extensions/ffi.cpp  View file @ 44740c6c
...
@@ -38,12 +38,11 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
      break;
    // case xla::ffi::DataType::F8E8M0FNU:
    //   return DType::kFloat8E8M0;
    //   break;
    case xla::ffi::DataType::F8E8M0FNU:
      return DType::kFloat8E8M0;
      break;
    default:
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      if (type_num == 33) return DType::kFloat8E8M0;
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                 static_cast<int>(type_num));
      break;
...
transformer_engine/jax/csrc/extensions/gemm.cpp  View file @ 44740c6c
...
@@ -6,10 +6,14 @@
#include "transformer_engine/gemm.h"

#include <memory>
#include <string_view>
#include <tuple>

#include "../extensions.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"

#define MXFP8_BLOCK_SIZE 32
...
@@ -17,6 +21,187 @@
namespace transformer_engine {
namespace jax {

static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
  // Move the pointer to the next 256B aligned address
  return reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(ptr) + 255) &
                                     ~static_cast<uintptr_t>(255));
}

std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
    cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv,
    JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) {
  // Set tensor data with collapsed 2D shape
  auto buffer_dims = buffer.dimensions();
  std::vector<size_t> input_shape = {product(buffer_dims, 0, axis_boundary),
                                     product(buffer_dims, axis_boundary, buffer_dims.size())};
  auto input_dtype = convert_ffi_datatype_to_te_dtype(buffer.element_type());
  TensorWrapper input(get_nvte_scaling_mode(scaling_mode));
  if (rowwise) {
    input.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
  } else {
    input.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
  }

  // Set scaling factor for quantized tensors
  if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) {
    NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands.");
    NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM.");

    std::vector<size_t> scale_shape = {1};
    if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
      // Block scaling also needs to be collapsed to match 2D data
      scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary),
                     product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())};
    }
    auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
    if (rowwise) {
      input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
    } else {
      input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape);
    }

    // Swizzle scaling factors for MXFP8
    if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) {
      // Get the swizzle buffer
      NVTE_CHECK(swizzled_scale_inv->element_count() > 0,
                 "Missing swizzled inverse scale buffer in the JAX primitive.");
      auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type());
      auto swizzled_scale_inv_dtype =
          convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type());
      NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1,
                 "Inverse scale factors need to have an 8-bit data type.");

      // Create tensor to hold swizzled scale factor
      TensorWrapper output(get_nvte_scaling_mode(scaling_mode));
      if (rowwise) {
        output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape);
        output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
      } else {
        output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape);
        output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
                                        scale_shape);
      }

      // Launch swizzle kernel
      nvte_swizzle_scaling_factors(input.data(), output.data(), stream);

      // Set swizzled scales into the input tensor
      if (rowwise) {
        input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape);
      } else {
        input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype,
                                       scale_shape);
      }
    }
  }

  return std::make_tuple(std::move(input), input_shape);
}
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
                   Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
                   Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
                   Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace,
                   JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
                   int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
                   bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) {
  // Operands (this includes swizzling MXFP8 scaling factors)
  // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
  //       device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
  bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
                         (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported()));
  bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed;
  bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed;
  auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(
      stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise);
  auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(
      stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise);

  // Output tensor
  std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
                                   (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
  auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
  NVTE_CHECK(out_.numel() == output->element_count(),
             "cuBLAS GEMM output buffer size is incorrect, "
             "expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ",
             output->element_count(), " elements ", to_string_like(output->dimensions()));

  // Bias input to forward pass or bias gradient output from backward pass
  void *bias_ptr = nullptr;
  std::vector<size_t> bias_shape = {0};
  DType bias_dtype = out_dtype;
  if (fuse_bias) {
    if (!grad) {
      NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
                 "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
    }
    bias_ptr = bias_grad->untyped_data();
    bias_shape.at(0) = bias_grad->dimensions().front();
    bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
  }
  auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);

  // Pre-GeLU output from forward pass or input to backward pass
  void *pre_gelu_ptr = nullptr;
  std::vector<size_t> pre_gelu_shape = {0};
  DType pre_gelu_dtype = out_dtype;
  if (gelu_input.element_count() > 0) {
    if (grad) {
      NVTE_CHECK(pre_gelu_out->untyped_data() == gelu_input.untyped_data(),
                 "Missing operand-output aliasing in GemmPrimitive: gelu_input <-> pre_gelu_out");
    }
    pre_gelu_ptr = pre_gelu_out->untyped_data();
    pre_gelu_shape = {product(pre_gelu_out->dimensions(), 0, pre_gelu_out->dimensions().size() - 1),
                      static_cast<size_t>(pre_gelu_out->dimensions().back())};
    pre_gelu_dtype = convert_ffi_datatype_to_te_dtype(pre_gelu_out->element_type());
  }
  auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype);

  // cuBLAS workspace + 256 alignment enforcement
  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
  workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
  std::vector<size_t> workspace_shape = {static_cast<size_t>(workspace->element_count()) - 256};
  auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte);

  // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
  auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
  nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
                   rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
                   use_split_accumulator, num_math_sm, stream);

  return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()      // stream
                                  .Arg<Buffer_Type>()          // lhs
                                  .Arg<Buffer_Type>()          // lhs_scale_inv
                                  .Arg<Buffer_Type>()          // rhs
                                  .Arg<Buffer_Type>()          // rhs_scale_inv
                                  .Arg<Buffer_Type>()          // bias
                                  .Arg<Buffer_Type>()          // gelu_input
                                  .Ret<Buffer_Type>()          // output
                                  .Ret<Buffer_Type>()          // bias_grad
                                  .Ret<Buffer_Type>()          // pre_gelu_out
                                  .Ret<Buffer_Type>()          // lhs_swizzled
                                  .Ret<Buffer_Type>()          // rhs_swizzled
                                  .Ret<Buffer_Type>()          // workspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("lhs_axis_boundary")
                                  .Attr<int64_t>("rhs_axis_boundary")
                                  .Attr<bool>("lhs_transposed")
                                  .Attr<bool>("rhs_transposed")
                                  .Attr<bool>("fuse_bias")
                                  .Attr<bool>("fuse_gelu")
                                  .Attr<bool>("grad")
                                  .Attr<bool>("use_split_accumulator"),
                              FFI_CudaGraph_Traits);

Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                          Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
                          Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output,
...
@@ -54,15 +239,43 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
  NVTE_CHECK(group_sizes.dimensions().size() == 1);
  size_t num_gemms = group_sizes.dimensions()[0];

  // It is weird that TE/Common GEMM only use colwise for MXFP8
  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
  const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;

  // Outputs
  auto out_ptr = reinterpret_cast<uint8_t *>(output->untyped_data());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
  // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned
  auto workspace_ptr = reinterpret_cast<uint8_t *>(
      (reinterpret_cast<uintptr_t>(workspace->untyped_data()) + 255) & ~static_cast<uintptr_t>(255));
  auto workspace_total_size = product(workspace->dimensions()) - 255;
  auto workspace_size = workspace_total_size / num_streams;
  auto workspace_ptr = reinterpret_cast<uint8_t *>(workspace->untyped_data());
  workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr);
  auto workspace_total_size = product(workspace->dimensions());
  auto lhs_sinv_size = product(lhs_sinv.dimensions());
  auto rhs_sinv_size = product(rhs_sinv.dimensions());
  const size_t workspace_alignment_padding = 256;
  const size_t tensor_scaling_sinv_aligment = 16;
  const size_t mxfp8_scaling_sinv_alignment_padding = 256;
  auto workspace_size = workspace_total_size - workspace_alignment_padding;
  if (is_mxfp8_scaling) {
    // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment
    // padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded
    // by 128x4.
    workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding);
  } else if (is_tensor_scaling) {
    // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned
    // by 16 bytes to meet the requirement of CUDA 12.9.1 and later.
    workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size);
  }
  workspace_size = workspace_size / num_streams;
  auto swizzled_lhs_sinv_ptr = workspace_ptr + workspace_size * num_streams;
  swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr);
  auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size;
  swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr);
  auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr;  // Already 256B aligned
  auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment;

  size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype);
  size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype);
...
@@ -71,6 +284,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
  size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype);
  size_t out_dtype_bytes = te_dtype_bytes(out_dtype);

  if (is_tensor_scaling) {
    cudaStream_t stream_0 = nvte_get_compute_stream(0);
    size_t dpitch = tensor_scaling_sinv_aligment;
    size_t spitch = lhs_sinv_dtype_bytes;
    size_t width = lhs_sinv_dtype_bytes;
    size_t height = lhs_sinv_size;
    cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height,
                      cudaMemcpyDeviceToDevice, stream_0);
    spitch = rhs_sinv_dtype_bytes;
    width = rhs_sinv_dtype_bytes;
    height = rhs_sinv_size;
    cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height,
                      cudaMemcpyDeviceToDevice, stream_0);
    lhs_sinv_ptr = lhs_scatter_aligned_ptr;
    rhs_sinv_ptr = rhs_scatter_aligned_ptr;
  }

  NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)");
  NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes,
             "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)");
...
@@ -120,12 +350,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
  auto bias_shape = std::vector<size_t>{has_bias ? n : 0};

  const int arch = cuda::sm_arch();
  // It is weird that TE/Common GEMM only use colwise for MXFP8
  const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype);
  const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;
  const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans;
  const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans;
  if (arch < 100 && is_fp8_gemm) {
    NVTE_CHECK(!lhs_is_trans && rhs_is_trans,
               "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ",
...
@@ -135,6 +359,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
  // These lists are to keep the TensorWrapper objects alive
  std::vector<TensorWrapper> lhs_wrapper_list;
  std::vector<TensorWrapper> rhs_wrapper_list;
  std::vector<TensorWrapper> lhs_swizzle_wrapper_list;  // For MXFP8 scale_inv swizzling
  std::vector<TensorWrapper> rhs_swizzle_wrapper_list;
  std::vector<TensorWrapper> bias_wrapper_list;
  std::vector<TensorWrapper> pre_gelu_wrapper_list;
  std::vector<TensorWrapper> out_wrapper_list;
...
@@ -143,66 +369,119 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
  // These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
  std::vector<NVTETensor> lhs_list;
  std::vector<NVTETensor> rhs_list;
  std::vector<NVTETensor> lhs_swizzle_list;
  std::vector<NVTETensor> rhs_swizzle_list;
  std::vector<NVTETensor> bias_list;
  std::vector<NVTETensor> pre_gelu_list;
  std::vector<NVTETensor> out_list;
  std::vector<NVTETensor> workspace_list;

  size_t lhs_sinv_total_size = 0;
  size_t rhs_sinv_total_size = 0;
  std::vector<void *> zero_out_dptr_list;
  std::vector<size_t> zero_out_size_list;

  for (size_t i = 0; i < num_gemms; i++) {
    // Matrix data shapes
    size_t m_i = dim_list_host[i];
    auto lhs_shape = std::vector<size_t>{m_i, k};
    auto rhs_shape = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
    auto out_shape = std::vector<size_t>{m_i, n};
    auto lhs_shape_i = std::vector<size_t>{m_i, k};
    auto rhs_shape_i = std::vector<size_t>{rhs_is_trans ? n : k, rhs_is_trans ? k : n};
    auto out_shape_i = std::vector<size_t>{m_i, n};
    if (is_grouped_dense_wgrad) {
      size_t k_i = dim_list_host[i];
      lhs_shape[0] = lhs_is_trans ? k_i : m;
      lhs_shape[1] = lhs_is_trans ? m : k_i;
      rhs_shape[0] = rhs_is_trans ? n : k_i;
      rhs_shape[1] = rhs_is_trans ? k_i : n;
      out_shape[0] = m;
      out_shape[1] = n;
      lhs_shape_i[0] = lhs_is_trans ? k_i : m;
      lhs_shape_i[1] = lhs_is_trans ? m : k_i;
      rhs_shape_i[0] = rhs_is_trans ? n : k_i;
      rhs_shape_i[1] = rhs_is_trans ? k_i : n;
      out_shape_i[0] = m;
      out_shape_i[1] = n;
    }

    size_t lhs_size = lhs_shape_i[0] * lhs_shape_i[1];
    size_t rhs_size = rhs_shape_i[0] * rhs_shape_i[1];
    size_t out_size = out_shape_i[0] * out_shape_i[1];
    bool is_empty_gemm = lhs_size == 0 || rhs_size == 0;
    if (is_empty_gemm && out_size > 0) {
      zero_out_dptr_list.push_back(out_ptr);
      zero_out_size_list.push_back(out_size * out_dtype_bytes);
    }

    // Set matrix data pointers
    auto lhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto rhs_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape, out_dtype);
    auto out_i = TensorWrapper(static_cast<void *>(out_ptr), out_shape_i, out_dtype);
    void *lhs_vptr = static_cast<void *>(lhs_ptr);
    void *rhs_vptr = static_cast<void *>(rhs_ptr);
    if (rhs_use_colwise)  // MatA to enter cuBLAS
      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape);
      rhs_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
    else
      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape);
      rhs_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
    if (lhs_use_colwise)  // MatB to enter cuBLAS
      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape);
      lhs_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
    else
      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape);

    // Scale_inv shapes
    auto lhs_sinv_size = std::vector<size_t>{1};
    auto rhs_sinv_size = std::vector<size_t>{1};
    if (is_mxfp8_scaling) {
      NVTE_CHECK(k % MXFP8_BLOCK_SIZE == 0, "MXFP8 K-dim being divisble by %d (got %d)",
                 MXFP8_BLOCK_SIZE, k);
      size_t scale_k = k / MXFP8_BLOCK_SIZE;
      lhs_sinv_size[0] = m_i * scale_k;
      rhs_sinv_size[0] = n * scale_k;
      // Need to add swizzle here
    }
      lhs_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);

    // Set scale_inv pointers
    // Set scale_inv shapes and pointers
    void *rhs_sinv_vptr = static_cast<void *>(rhs_sinv_ptr);
    void *lhs_sinv_vptr = static_cast<void *>(lhs_sinv_ptr);
    if (is_fp8_gemm) {
      size_t lhs_sinv_size_i = 0;
      size_t rhs_sinv_size_i = 0;
      if (is_tensor_scaling) {
        auto tensor_scaling_sinv_shape = std::vector<size_t>{1};
        // If is_empty_gemm, scale_inv does not have the corresponding value, do not move the pointers
        if (!is_empty_gemm) {
          lhs_sinv_size_i = tensor_scaling_sinv_aligment / lhs_sinv_dtype_bytes;
          rhs_sinv_size_i = tensor_scaling_sinv_aligment / rhs_sinv_dtype_bytes;
        }
        if (rhs_use_colwise)  // MatA to enter cuBLAS
          rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
          rhs_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
        else
          rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_size);
          rhs_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, tensor_scaling_sinv_shape);
        if (lhs_use_colwise)  // MatB to enter cuBLAS
          lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
          lhs_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
        else
          lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_size);
          lhs_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, tensor_scaling_sinv_shape);
      } else if (is_mxfp8_scaling) {
        auto lhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
        auto rhs_swizzle_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
        void *swizzled_lhs_sinv_vptr = static_cast<void *>(swizzled_lhs_sinv_ptr);
        void *swizzled_rhs_sinv_vptr = static_cast<void *>(swizzled_rhs_sinv_ptr);
        // {lhs, rhs}_swizzle_i point to unswizzled scale_inv data as input, while {lhs, rhs}_i
        // point to swizzled scale_inv data (store on workspace, only used for GEMM).
        // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers
        auto lhs_sinv_shape_i = get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise);
        auto rhs_sinv_shape_i = get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise);
        lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1];
        rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1];
        if (lhs_use_colwise) {
          lhs_swizzle_i.set_columnwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
          lhs_swizzle_i.set_columnwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
          lhs_i.set_columnwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
        } else {
          lhs_swizzle_i.set_rowwise_data(lhs_vptr, lhs_dtype, lhs_shape_i);
          lhs_swizzle_i.set_rowwise_scale_inv(lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
          lhs_i.set_rowwise_scale_inv(swizzled_lhs_sinv_vptr, lhs_sinv_dtype, lhs_sinv_shape_i);
        }
        if (rhs_use_colwise) {
          rhs_swizzle_i.set_columnwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
          rhs_swizzle_i.set_columnwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
          rhs_i.set_columnwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
        } else {
          rhs_swizzle_i.set_rowwise_data(rhs_vptr, rhs_dtype, rhs_shape_i);
          rhs_swizzle_i.set_rowwise_scale_inv(rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
          rhs_i.set_rowwise_scale_inv(swizzled_rhs_sinv_vptr, rhs_sinv_dtype, rhs_sinv_shape_i);
        }
        if (!is_empty_gemm) {
          lhs_swizzle_wrapper_list.push_back(std::move(lhs_swizzle_i));
          rhs_swizzle_wrapper_list.push_back(std::move(rhs_swizzle_i));
          lhs_swizzle_list.push_back(lhs_swizzle_wrapper_list.back().data());
          rhs_swizzle_list.push_back(rhs_swizzle_wrapper_list.back().data());
        }
      } else {
        NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING,
                   "Unsupported scaling mode: ", static_cast<int>(scaling_mode));
...
@@ -212,16 +491,23 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
    auto pre_gelu_i = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype);

    // Update pointer for the next GEMM pair
    lhs_ptr += lhs_shape[0] * lhs_shape[1] * lhs_dtype_bytes;
    rhs_ptr += rhs_shape[0] * rhs_shape[1] * rhs_dtype_bytes;
    out_ptr += out_shape[0] * out_shape[1] * out_dtype_bytes;
    lhs_ptr += lhs_size * lhs_dtype_bytes;
    rhs_ptr += rhs_size * rhs_dtype_bytes;
    out_ptr += out_size * out_dtype_bytes;
    if (is_fp8_gemm) {
      lhs_sinv_ptr += lhs_sinv_size[0] * lhs_sinv_dtype_bytes;
      rhs_sinv_ptr += rhs_sinv_size[0] * rhs_sinv_dtype_bytes;
      lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
      rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
      lhs_sinv_total_size += lhs_sinv_size_i;
      rhs_sinv_total_size += rhs_sinv_size_i;
      if (is_mxfp8_scaling) {
        swizzled_lhs_sinv_ptr += lhs_sinv_size_i * lhs_sinv_dtype_bytes;
        swizzled_rhs_sinv_ptr += rhs_sinv_size_i * rhs_sinv_dtype_bytes;
      }
    }
    if (has_bias) bias_ptr += n * bias_dtype_bytes;

    // Move objects to the lists to keep them alive
    if (is_empty_gemm) continue;
    lhs_wrapper_list.push_back(std::move(lhs_i));
    rhs_wrapper_list.push_back(std::move(rhs_i));
    out_wrapper_list.push_back(std::move(out_i));
...
@@ -244,10 +530,45 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
    workspace_ptr += workspace_size;
  }

  if (is_fp8_gemm) {
    if (is_tensor_scaling) {
      lhs_sinv_size *= tensor_scaling_sinv_aligment;
      rhs_sinv_size *= tensor_scaling_sinv_aligment;
    }
    NVTE_CHECK(lhs_sinv_total_size <= lhs_sinv_size, "Actual total lhs_sinv size ",
               lhs_sinv_total_size, " exceeds estimated upper bound ", lhs_sinv_size);
    NVTE_CHECK(rhs_sinv_total_size <= rhs_sinv_size, "Actual total rhs_sinv size ",
               rhs_sinv_total_size, " exceeds estimated upper bound ", rhs_sinv_size);
  }

  size_t num_non_empty_gemms = lhs_list.size();
  if (is_mxfp8_scaling) {
    for (int i = 0; i < num_non_empty_gemms; i++) {
      // The i-th GEMM will use the (i % num_streams)-th stream to compute,
      // use the same stream to swizzle the scaling factors to make sure that
      // the swizzling is done before the GEMM computation starts.
      int stream_id = i % num_streams;
      cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
      nvte_swizzle_scaling_factors(lhs_swizzle_list[i], lhs_list[i], stream_i);
      nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i);
    }
  }

  // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM
  size_t num_zero_outs = zero_out_dptr_list.size();
  for (int i = 0; i < num_zero_outs; i++) {
    int stream_id = i % num_streams;
    cudaStream_t stream_i = nvte_get_compute_stream(stream_id);
    void *dptr = zero_out_dptr_list[i];
    size_t count = zero_out_size_list[i];
    NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i));
  }

  nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(),
                                pre_gelu_list.data(), num_gemms, rhs_is_trans, lhs_is_trans, grad,
                                workspace_list.data(), accumulate, use_split_accumulator,
                                num_math_sm, stream);
                                pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans,
                                lhs_is_trans, grad, workspace_list.data(), accumulate,
                                use_split_accumulator, num_math_sm, stream);
  return ffi_with_cuda_error_check();
}
...
transformer_engine/jax/csrc/extensions/misc.h  View file @ 44740c6c
...
@@ -47,6 +47,15 @@ enum class JAXX_Scaling_Mode : int64_t {
  CURRENT_TENSOR_SCALING = 3,
};

inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) {
  return (mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING ||
          mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING);
}

inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) {
  return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING);
}

static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
  switch (mode) {
    case JAXX_Scaling_Mode::NO_SCALING:
...
transformer_engine/jax/csrc/extensions/pybind.cpp  View file @ 44740c6c
...
@@ -55,6 +55,11 @@ pybind11::dict Registrations() {
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(FusedAttnBackwardHandler));

  // GEMM
  dict["te_gemm_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
                     pybind11::arg("execute") = EncapsulateFFI(GemmHandler));

  // Grouped GEMM
  dict["te_grouped_gemm_ffi"] =
      pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler),
...
@@ -78,6 +83,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
  m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes);
  m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format);
  m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported);

  pybind11::enum_<DType>(m, "DType", pybind11::module_local())
      .value("kByte", DType::kByte)
...