Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ab3e5a92
Commit
ab3e5a92
authored
May 09, 2025
by
yuguo
Browse files
Merge commit '
04c730c0
' of...
Merge commit '
04c730c0
' of
https://github.com/NVIDIA/TransformerEngine
parents
a8d19fd9
04c730c0
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1093 additions
and
1346 deletions
+1093
-1346
transformer_engine/jax/csrc/extensions.h
transformer_engine/jax/csrc/extensions.h
+14
-8
transformer_engine/jax/csrc/extensions/activation.cpp
transformer_engine/jax/csrc/extensions/activation.cpp
+111
-87
transformer_engine/jax/csrc/extensions/gemm.cpp
transformer_engine/jax/csrc/extensions/gemm.cpp
+96
-110
transformer_engine/jax/csrc/extensions/misc.h
transformer_engine/jax/csrc/extensions/misc.h
+24
-1
transformer_engine/jax/csrc/extensions/normalization.cpp
transformer_engine/jax/csrc/extensions/normalization.cpp
+10
-11
transformer_engine/jax/csrc/extensions/pybind.cpp
transformer_engine/jax/csrc/extensions/pybind.cpp
+9
-9
transformer_engine/jax/csrc/extensions/quantization.cpp
transformer_engine/jax/csrc/extensions/quantization.cpp
+93
-41
transformer_engine/jax/dense.py
transformer_engine/jax/dense.py
+37
-22
transformer_engine/jax/flax/module.py
transformer_engine/jax/flax/module.py
+90
-102
transformer_engine/jax/flax/transformer.py
transformer_engine/jax/flax/transformer.py
+12
-4
transformer_engine/jax/layernorm_dense.py
transformer_engine/jax/layernorm_dense.py
+19
-9
transformer_engine/jax/layernorm_mlp.py
transformer_engine/jax/layernorm_mlp.py
+54
-21
transformer_engine/jax/praxis/__init__.py
transformer_engine/jax/praxis/__init__.py
+0
-9
transformer_engine/jax/praxis/module.py
transformer_engine/jax/praxis/module.py
+0
-311
transformer_engine/jax/praxis/transformer.py
transformer_engine/jax/praxis/transformer.py
+0
-408
transformer_engine/jax/quantize/dequantizer.py
transformer_engine/jax/quantize/dequantizer.py
+16
-7
transformer_engine/jax/quantize/helper.py
transformer_engine/jax/quantize/helper.py
+70
-35
transformer_engine/jax/quantize/quantizer.py
transformer_engine/jax/quantize/quantizer.py
+102
-68
transformer_engine/jax/quantize/scaling_modes.py
transformer_engine/jax/quantize/scaling_modes.py
+200
-42
transformer_engine/jax/quantize/tensor.py
transformer_engine/jax/quantize/tensor.py
+136
-41
No files found.
transformer_engine/jax/csrc/extensions.h
View file @
ab3e5a92
...
@@ -31,6 +31,9 @@
...
@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h"
#include "transformer_engine/activation.h"
#include "utils.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING
(
transformer_engine
::
jax
::
JAXX_Scaling_Mode
);
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
...
@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
...
@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ActLuHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ActLuHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DActLuDBiasQuantizeHandler
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
);
// Normalization
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
NormForwardHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
NormForwardHandler
);
...
@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
...
@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
w_dtype
,
DType
out_dtype
,
DType
w_dtype
,
DType
out_dtype
,
NVTE_Norm_Type
norm_type
,
int
scaling_mode
,
NVTE_Norm_Type
norm_type
,
JAXX_Scaling_Mode
scaling_mode
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
is_training
);
bool
is_training
);
...
@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
...
@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DequantizeHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DequantizeHandler
);
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
);
DType
in_dtype
,
DType
out_dtype
,
JAXX_Scaling_Mode
scaling_mode
,
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DActLuDBiasQuantizeHandler
);
QuantizeLayout
q_layout
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
int
scaling_mode
,
bool
is_2x
);
// Softmax
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ScaledSoftmaxForwardHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ScaledSoftmaxForwardHandler
);
...
...
transformer_engine/jax/csrc/extensions/activation.cpp
View file @
ab3e5a92
...
@@ -11,21 +11,13 @@
...
@@ -11,21 +11,13 @@
#include "transformer_engine/cast.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/c_api.h"
namespace
{
bool
is_gated
(
NVTE_Activation_Type
act_type
)
{
return
act_type
==
NVTE_Activation_Type
::
GEGLU
||
act_type
==
NVTE_Activation_Type
::
SWIGLU
||
act_type
==
NVTE_Activation_Type
::
REGLU
||
act_type
==
NVTE_Activation_Type
::
QGEGLU
||
act_type
==
NVTE_Activation_Type
::
SREGLU
;
}
}
// namespace
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
Error_Type
ActLuFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Error_Type
ActLuFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
colwise_output_buf
,
Result_Type
output_buf
,
Result_Type
colwise_output_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise_scale_inv_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
int64_t
act_enum
,
int64_t
scaling_mode
_enum
,
Result_Type
amax_buf
,
int64_t
act_enum
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x_int
)
{
bool
is_2x_int
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
...
@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
...
@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto
n
=
input_dims
.
back
();
auto
n
=
input_dims
.
back
();
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
auto
act_len
=
input_dims
[
input_dims
.
size
()
-
2
];
auto
act_len
=
input_dims
[
input_dims
.
size
()
-
2
];
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
is_2x
=
static_cast
<
bool
>
(
is_2x_int
);
auto
is_2x
=
static_cast
<
bool
>
(
is_2x_int
);
auto
flatten_axis
=
output_buf
->
dimensions
().
size
()
-
1
;
// output does not have act axis
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
act_len
*
n
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
act_len
*
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
static_cast
<
DType
>
(
in_dtype
));
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
static_cast
<
DType
>
(
in_dtype
));
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_
scaling_mode
(
scaling_mode
)
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
output_shape
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
scale_inv_buf
->
untyped_data
(),
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
std
::
vector
<
size_t
>
{
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
scale_inv_buf
->
dimensions
().
back
()});
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
}
else
{
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
output_tensor
.
set_rowwise_scale_inv
(
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
scale_inv_buf
->
untyped_data
(),
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
}
if
(
is_2x
)
{
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
colwise_output
,
static_cast
<
DType
>
(
out_dtype
),
output_shape
);
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
output_tensor
.
set_columnwise_scale_inv
(
?
output_trans_shape
colwise_scale_inv_buf
->
untyped_data
(),
:
output_shape
;
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
output_tensor
.
set_columnwise_data
(
colwise_output
,
out_dtype
,
tmp_shape
);
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
if
(
is_fp8_dtype
(
out_dtype
))
{
colwise_scale_inv_buf
->
dimensions
().
back
()});
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
colwise_scale_inv_buf
;
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
}
}
switch
(
act_type
)
{
switch
(
act_type
)
{
...
@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
...
@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.
Ret
<
Buffer_Type
>
()
// scale_inv colwise
.
Ret
<
Buffer_Type
>
()
// scale_inv colwise
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// amax
.
Attr
<
int64_t
>
(
"act_enum"
)
.
Attr
<
int64_t
>
(
"act_enum"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
bool
>
(
"is_2x"
),
.
Attr
<
bool
>
(
"is_2x"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
DType
in_dtype
,
DType
out_dtype
,
int
scaling_mode
,
bool
is_2x
)
{
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
dact_input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
dact_input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
...
@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto
dact_input_tensor
=
auto
dact_input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dact_input_shape
,
in_dtype
);
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dact_input_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
static_cast
<
NVTES
caling
M
ode
>
(
scaling_mode
));
auto
output_tensor
=
TensorWrapper
(
get_nvte_s
caling
_m
ode
(
scaling_mode
));
output_tensor
.
set_rowwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
...
@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
}
if
(
is_2x
)
{
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
auto
&
tmp_shape
=
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
?
output_trans_shape
output_trans_shape
);
:
output_shape
;
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
tmp_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
...
@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
}
}
}
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
)
{
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_amax
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
output_tensor
.
set_amax
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
output_tensor
.
set_scale
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
...
@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type
DActLuDBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Error_Type
DActLuDBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
act_input_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
act_input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
output
_trans
_buf
,
Result_Type
output_buf
,
Result_Type
colwise_
output_buf
,
Result_Type
scale_inv_buf
,
Result_Type
trans
_scale_inv_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise
_scale_inv_buf
,
Result_Type
amax_
out_
buf
,
Result_Type
dbias_buf
,
Result_Type
amax_buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
int64_t
scaling_mode_enum
,
bool
is_2x
,
Result_Type
workspace_buf
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_dbias
,
int64_t
act_enum
)
{
int64_t
act_enum
,
bool
is_2x
,
bool
is_dbias
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
act_input
=
act_input_buf
.
untyped_data
();
auto
*
act_input
=
act_input_buf
.
untyped_data
();
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
auto
flatten_axis
=
output_buf
->
dimensions
().
size
()
-
2
;
// output has act axis
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans
_buf
->
untyped_data
();
auto
*
colwise_output
=
colwise_output
_buf
->
untyped_data
();
auto
*
dbias
=
dbias_buf
->
untyped_data
();
auto
*
dbias
=
dbias_buf
->
untyped_data
();
void
*
workspace
=
workspace_buf
->
untyped_data
();
void
*
workspace
=
workspace_buf
->
untyped_data
();
...
@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
...
@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto
act_input_dims
=
act_input_buf
.
dimensions
();
auto
act_input_dims
=
act_input_buf
.
dimensions
();
auto
workspace_dims
=
workspace_buf
->
dimensions
();
auto
workspace_dims
=
workspace_buf
->
dimensions
();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto
input_ranks
=
input_dims
.
size
();
auto
act_len
=
act_input_dims
[
act_input_dims
.
size
()
-
2
];
auto
act_input_ranks
=
act_input_dims
.
size
();
NVTE_CHECK
(
act_input_dims
.
back
()
==
input_dims
.
back
(),
auto
m
=
product
(
act_input_dims
,
0
,
act_input_dims
.
size
()
-
1
);
"Shape mismatch between activation input and gradient input"
);
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
auto
m
=
product
(
act_input_dims
,
0
,
act_input_dims
.
size
()
-
2
);
auto
n
=
act_input_dims
.
back
();
auto
n
=
input_dims
.
back
();
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
input_dims
.
back
()};
auto
act_input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
act_input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
*
act_len
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
*
act_len
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
*
act_len
,
m
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
*
act_len
};
std
::
vector
<
size_t
>
workspace_shape
(
workspace_dims
.
begin
(),
workspace_dims
.
end
());
std
::
vector
<
size_t
>
workspace_shape
(
workspace_dims
.
begin
(),
workspace_dims
.
end
());
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
act_input_tensor
=
TensorWrapper
(
act_input
,
act_input_shape
,
in_dtype
);
auto
act_input_tensor
=
TensorWrapper
(
act_input
,
act_input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
_out
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
cudaMemsetAsync
(
amax
_out
,
0
,
sizeof
(
float
),
stream
);
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax_out
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
}
}
if
(
is_2x
)
{
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
colwise_output
,
out_dtype
,
tmp_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
colwise_scale_inv_buf
=
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
trans_scale_inv_buf
;
?
scale_inv_buf
output_tensor
.
set_columnwise_scale_inv
(
:
colwise_scale_inv_buf
;
colwise_scale_inv_buf
->
untyped_data
(),
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
output_tensor
.
set_columnwise_scale_inv
(
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
std
::
vector
<
size_t
>
{
1
});
colwise_scale_inv_buf
->
dimensions
().
back
()});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
}
}
}
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
workspace_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
workspace_dtype
);
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK
(
!
(
is_gated
(
act_type
)
&&
is_dbias
),
"Unsupported DGatedActedDBias Fusion!"
);
NVTE_CHECK
(
!
(
act_len
==
2
&&
is_dbias
),
"Unsupported DGatedActedDBias Fusion!"
);
NVTE_CHECK
(
!
(
scaling_mode
==
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
&&
is_2x
&&
NVTE_CHECK
(
!
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
&&
is_2x
&&
act_len
==
2
),
is_gated
(
act_type
)),
"TE/common does not support delayed scaling for 2x with gated activations."
);
"TE/common does not support delayed scaling for 2x with gated activations."
);
if
(
is_dbias
)
{
if
(
is_dbias
)
{
...
@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
...
@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"act_enum"
)
.
Attr
<
bool
>
(
"is_2x"
)
.
Attr
<
bool
>
(
"is_2x"
)
.
Attr
<
bool
>
(
"is_dbias"
)
.
Attr
<
bool
>
(
"is_dbias"
),
.
Attr
<
int64_t
>
(
"act_enum"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
}
// namespace jax
}
// namespace jax
}
// namespace transformer_engine
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/gemm.cpp
View file @
ab3e5a92
...
@@ -15,47 +15,34 @@
...
@@ -15,47 +15,34 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
constexpr
static
size_t
MXFP8_BLOCK_SIZE
=
32
;
Error_Type
GroupedGemmFFI
(
cudaStream_t
stream
,
Variadic_Buffer_Type
input_list
,
Variadic_Result_Type
output_list
,
int64_t
num_gemms
,
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
JAXX_Scaling_Mode
scaling_mode
,
int64_t
has_bias
)
{
Error_Type
GroupedGemmImpl
(
uint8_t
*
lhs_ptr
,
const
DType
&
lhs_dtype
,
uint8_t
*
lhs_sinv_ptr
,
const
DType
&
lhs_sinv_dtype
,
uint8_t
*
rhs_ptr
,
const
DType
&
rhs_dtype
,
uint8_t
*
rhs_sinv_ptr
,
const
DType
&
rhs_sinv_dtype
,
uint8_t
*
bias_ptr
,
const
DType
&
bias_dtype
,
uint8_t
*
out_ptr
,
const
DType
&
out_dtype
,
uint8_t
*
workspace_ptr
,
const
size_t
workspace_size
,
size_t
num_gemms
,
int32_t
*
dim_list_ptr
,
const
int64_t
&
scaling_mode
,
cudaStream_t
stream
)
{
size_t
lhs_dtype_bytes
=
te_dtype_bytes
(
lhs_dtype
);
size_t
rhs_dtype_bytes
=
te_dtype_bytes
(
rhs_dtype
);
size_t
lhs_sinv_dtype_bytes
=
te_dtype_bytes
(
lhs_sinv_dtype
);
size_t
rhs_sinv_dtype_bytes
=
te_dtype_bytes
(
rhs_sinv_dtype
);
size_t
bias_dtype_bytes
=
te_dtype_bytes
(
bias_dtype
);
size_t
out_dtype_bytes
=
te_dtype_bytes
(
out_dtype
);
NVTE_CHECK
(
lhs_dtype_bytes
==
rhs_dtype_bytes
,
"sizeof(lhs_dtype) != sizeof(rhs_dtype)"
);
NVTE_CHECK
(
lhs_sinv_dtype_bytes
==
rhs_sinv_dtype_bytes
,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"
);
size_t
dim_list_bytes
=
sizeof
(
int32_t
)
*
3
*
num_gemms
;
std
::
unique_ptr
<
int32_t
[]
>
dim_list_host
=
std
::
make_unique
<
int32_t
[]
>
(
3
*
num_gemms
);
cudaMemcpyAsync
(
dim_list_host
.
get
(),
dim_list_ptr
,
dim_list_bytes
,
cudaMemcpyDeviceToHost
,
stream
);
// Note: This may break cudaGraph.
cudaStreamSynchronize
(
stream
);
// Notes on matrix layouts and transpose:
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// Jax uses row-major
data_
layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// cuBLAS uses column-major
data_
layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if
(
num_gemms
<=
0
)
{
return
ffi_with_cuda_error_check
();
}
size_t
expected_input_size
=
has_bias
?
5
*
num_gemms
:
4
*
num_gemms
;
size_t
expected_output_size
=
num_gemms
+
1
;
size_t
actual_input_size
=
input_list
.
size
();
size_t
actual_output_size
=
output_list
.
size
();
NVTE_CHECK
(
actual_input_size
==
expected_input_size
,
"Expected %zu input tensors, got %zu"
,
expected_input_size
,
actual_input_size
);
NVTE_CHECK
(
actual_output_size
==
expected_output_size
,
"Expected %zu output tensors, got %zu"
,
expected_output_size
,
actual_output_size
);
bool
trans_lhs
=
true
;
bool
trans_lhs
=
true
;
bool
trans_rhs
=
false
;
bool
trans_rhs
=
false
;
auto
num_math_sm
=
cuda
::
sm_count
()
-
getenv
<
int
>
(
"NVTE_EXT_MARGIN_SM"
,
0
);
auto
num_math_sm
=
cuda
::
sm_count
()
-
getenv
<
int
>
(
"NVTE_EXT_MARGIN_SM"
,
0
);
...
@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
...
@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
std
::
vector
<
NVTETensor
>
out_list
;
std
::
vector
<
NVTETensor
>
out_list
;
std
::
vector
<
NVTETensor
>
workspace_list
;
std
::
vector
<
NVTETensor
>
workspace_list
;
int
lhs_list_offset
=
0
;
int
rhs_list_offset
=
num_gemms
;
int
lhs_sinv_list_offset
=
2
*
num_gemms
;
int
rhs_sinv_list_offset
=
3
*
num_gemms
;
int
bias_list_offset
=
4
*
num_gemms
;
int
out_list_offset
=
0
;
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
size_t
m
=
dim_list_host
[
i
*
3
];
Buffer_Type
lhs_i
=
input_list
.
get
<
Buffer_Type
>
(
lhs_list_offset
+
i
).
value
();
size_t
n
=
dim_list_host
[
i
*
3
+
1
];
Buffer_Type
rhs_i
=
input_list
.
get
<
Buffer_Type
>
(
rhs_list_offset
+
i
).
value
();
size_t
k
=
dim_list_host
[
i
*
3
+
2
];
Buffer_Type
lhs_sinv_i
=
input_list
.
get
<
Buffer_Type
>
(
lhs_sinv_list_offset
+
i
).
value
();
Buffer_Type
rhs_sinv_i
=
input_list
.
get
<
Buffer_Type
>
(
rhs_sinv_list_offset
+
i
).
value
();
Result_Type
out_i
=
output_list
.
get
<
Buffer_Type
>
(
out_list_offset
+
i
).
value
();
DType
lhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_i
.
element_type
());
DType
rhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_i
.
element_type
());
DType
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
out_i
->
element_type
());
void
*
lhs_ptr
=
lhs_i
.
untyped_data
();
void
*
rhs_ptr
=
rhs_i
.
untyped_data
();
void
*
lhs_sinv_ptr
=
lhs_sinv_i
.
untyped_data
();
void
*
rhs_sinv_ptr
=
rhs_sinv_i
.
untyped_data
();
void
*
out_ptr
=
out_i
->
untyped_data
();
// Placeholder for bias since it can be empty
DType
bias_dtype
=
DType
::
kFloat32
;
void
*
bias_ptr
=
nullptr
;
auto
lhs_shape_
=
lhs_i
.
dimensions
();
auto
rhs_shape_
=
rhs_i
.
dimensions
();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t
m
=
lhs_shape_
[
1
];
size_t
n
=
rhs_shape_
[
1
];
size_t
k
=
lhs_shape_
[
2
];
auto
lhs_shape
=
std
::
vector
<
size_t
>
{
m
,
k
};
auto
lhs_shape
=
std
::
vector
<
size_t
>
{
m
,
k
};
auto
rhs_shape
=
std
::
vector
<
size_t
>
{
n
,
k
};
auto
rhs_shape
=
std
::
vector
<
size_t
>
{
n
,
k
};
...
@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
...
@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto
lhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
auto
lhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
auto
rhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
auto
rhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
if
(
scaling_mode
==
NVTE_NO_SCALING
||
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
NO_SCALING
||
auto
lhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_shape
,
lhs_dtype
,
nullptr
,
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
nullptr
,
reinterpret_cast
<
float
*>
(
lhs_sinv_ptr
));
float
*
amax_dptr
=
nullptr
;
auto
rhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_shape
,
rhs_dtype
,
nullptr
,
float
*
scale_dptr
=
nullptr
;
nullptr
,
reinterpret_cast
<
float
*>
(
rhs_sinv_ptr
));
auto
lhs_i_
=
TensorWrapper
(
lhs_ptr
,
lhs_shape
,
lhs_dtype
,
amax_dptr
,
scale_dptr
,
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
reinterpret_cast
<
float
*>
(
lhs_sinv_ptr
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
auto
rhs_i_
=
TensorWrapper
(
rhs_ptr
,
rhs_shape
,
rhs_dtype
,
amax_dptr
,
scale_dptr
,
}
else
if
(
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
reinterpret_cast
<
float
*>
(
rhs_sinv_ptr
));
NVTE_CHECK
(
k
%
MXFP8_BLOCK_SIZE
==
0
,
"MXFP8 K-dim being divisble by %d (got %d)"
,
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i_
));
MXFP8_BLOCK_SIZE
,
k
);
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i_
));
size_t
sinv_k
=
k
/
MXFP8_BLOCK_SIZE
;
}
else
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
)
{
lhs_sinv_shape
[
0
]
=
m
;
lhs_sinv_shape
[
1
]
=
sinv_k
;
rhs_sinv_shape
[
0
]
=
n
;
rhs_sinv_shape
[
1
]
=
sinv_k
;
// Note: the scale_inv array should have been swizzled in Python before lowering
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper
lhs_i
(
NVTE_MXFP8_1D_SCALING
);
auto
lhs_sinv_shape_
=
lhs_sinv_i
.
dimensions
();
TensorWrapper
rhs_i
(
NVTE_MXFP8_1D_SCALING
);
auto
rhs_sinv_shape_
=
rhs_sinv_i
.
dimensions
();
lhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_dtype
,
lhs_shape
);
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
rhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_dtype
,
rhs_shape
);
lhs_sinv_shape
[
i
]
=
lhs_sinv_shape_
[
i
];
lhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
lhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
rhs_sinv_shape
[
i
]
=
rhs_sinv_shape_
[
i
];
lhs_sinv_shape
);
}
rhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
rhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
rhs_sinv_shape
);
NVTEScalingMode
nvte_scaling_mode
=
get_nvte_scaling_mode
(
scaling_mode
);
TensorWrapper
lhs_i_
(
nvte_scaling_mode
);
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
TensorWrapper
rhs_i_
(
nvte_scaling_mode
);
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
lhs_i_
.
set_rowwise_data
(
lhs_ptr
,
lhs_dtype
,
lhs_shape
);
rhs_i_
.
set_rowwise_data
(
rhs_ptr
,
rhs_dtype
,
rhs_shape
);
lhs_i_
.
set_rowwise_scale_inv
(
lhs_sinv_ptr
,
DType
::
kFloat8E8M0
,
lhs_sinv_shape
);
rhs_i_
.
set_rowwise_scale_inv
(
rhs_sinv_ptr
,
DType
::
kFloat8E8M0
,
rhs_sinv_shape
);
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i_
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i_
));
}
else
{
}
else
{
NVTE_ERROR
(
"Unsupported scaling mode: "
,
scaling_mode
);
NVTE_ERROR
(
"Unsupported scaling mode: "
,
static_cast
<
int
>
(
scaling_mode
)
)
;
}
}
auto
out_i
=
TensorWrapper
(
static_cast
<
void
*>
(
out_ptr
),
out_shape
,
out_dtype
);
auto
out_i_
=
TensorWrapper
(
out_ptr
,
out_shape
,
out_dtype
);
lhs_ptr
+=
m
*
k
*
lhs_dtype_bytes
;
rhs_ptr
+=
n
*
k
*
rhs_dtype_bytes
;
out_ptr
+=
m
*
n
*
out_dtype_bytes
;
lhs_sinv_ptr
+=
lhs_sinv_shape
[
0
]
*
lhs_sinv_shape
[
1
]
*
lhs_sinv_dtype_bytes
;
rhs_sinv_ptr
+=
rhs_sinv_shape
[
0
]
*
rhs_sinv_shape
[
1
]
*
rhs_sinv_dtype_bytes
;
void
*
pre_gelu_ptr
=
nullptr
;
void
*
pre_gelu_ptr
=
nullptr
;
auto
bias_shape
=
std
::
vector
<
size_t
>
{
0
};
auto
bias_shape
=
std
::
vector
<
size_t
>
{
0
};
auto
pre_gelu_shape
=
std
::
vector
<
size_t
>
{
0
};
auto
pre_gelu_shape
=
std
::
vector
<
size_t
>
{
0
};
if
(
bias_ptr
!=
nullptr
)
bias_shape
[
0
]
=
n
;
if
(
has_bias
)
{
auto
bias_i_get
=
input_list
.
get
<
Buffer_Type
>
(
bias_list_offset
+
i
);
Buffer_Type
bias_i
=
bias_i_get
.
value
();
bias_ptr
=
bias_i
.
untyped_data
();
bias_dtype
=
convert_ffi_datatype_to_te_dtype
(
bias_i
.
element_type
());
bias_shape
[
0
]
=
n
;
}
auto
bias_i
=
TensorWrapper
(
bias_ptr
,
bias_shape
,
bias_dtype
);
auto
bias_i
=
TensorWrapper
(
bias_ptr
,
bias_shape
,
bias_dtype
);
if
(
bias_ptr
!=
nullptr
)
bias_ptr
+=
n
*
bias_dtype_bytes
;
auto
pre_gelu_i
=
TensorWrapper
(
pre_gelu_ptr
,
pre_gelu_shape
,
out_dtype
);
auto
pre_gelu_i
=
TensorWrapper
(
pre_gelu_ptr
,
pre_gelu_shape
,
out_dtype
);
out_wrapper_list
.
push_back
(
std
::
move
(
out_i
));
out_wrapper_list
.
push_back
(
std
::
move
(
out_i
_
));
bias_wrapper_list
.
push_back
(
std
::
move
(
bias_i
));
bias_wrapper_list
.
push_back
(
std
::
move
(
bias_i
));
pre_gelu_wrapper_list
.
push_back
(
std
::
move
(
pre_gelu_i
));
pre_gelu_wrapper_list
.
push_back
(
std
::
move
(
pre_gelu_i
));
...
@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
...
@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
out_list
.
push_back
(
out_wrapper_list
.
back
().
data
());
out_list
.
push_back
(
out_wrapper_list
.
back
().
data
());
}
}
auto
workspace_get
=
output_list
.
get
<
Buffer_Type
>
(
num_gemms
);
Result_Type
workspace
=
workspace_get
.
value
();
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
->
untyped_data
());
size_t
workspace_size
=
workspace
->
dimensions
()[
0
]
/
num_streams
;
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
auto
workspace_i
=
auto
workspace_i
=
...
@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
...
@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
Error_Type
GroupedGemmFFI
(
cudaStream_t
stream
,
Buffer_Type
lhs_flatten
,
Buffer_Type
lhs_sinv_flatten
,
Buffer_Type
rhs_flatten
,
Buffer_Type
rhs_sinv_flatten
,
Buffer_Type
bias_flatten
,
Buffer_Type
dim_list
,
Result_Type
out_flatten
,
Result_Type
workspace_flatten
,
int64_t
num_gemms
,
int64_t
scaling_mode
)
{
// Inputs
auto
lhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_flatten
.
untyped_data
());
auto
rhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_flatten
.
untyped_data
());
auto
lhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_sinv_flatten
.
untyped_data
());
auto
rhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_sinv_flatten
.
untyped_data
());
auto
bias_ptr
=
reinterpret_cast
<
uint8_t
*>
(
bias_flatten
.
untyped_data
());
auto
dim_list_ptr
=
reinterpret_cast
<
int32_t
*>
(
dim_list
.
untyped_data
());
auto
lhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_flatten
.
element_type
());
auto
rhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_flatten
.
element_type
());
auto
lhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_sinv_flatten
.
element_type
());
auto
rhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_sinv_flatten
.
element_type
());
auto
bias_dtype
=
convert_ffi_datatype_to_te_dtype
(
bias_flatten
.
element_type
());
// Outputs
auto
out_ptr
=
reinterpret_cast
<
uint8_t
*>
(
out_flatten
->
untyped_data
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
out_flatten
->
element_type
());
auto
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace_flatten
->
untyped_data
());
auto
workspace_size
=
workspace_flatten
->
dimensions
().
back
()
/
num_streams
;
return
GroupedGemmImpl
(
lhs_ptr
,
lhs_dtype
,
lhs_sinv_ptr
,
lhs_sinv_dtype
,
rhs_ptr
,
rhs_dtype
,
rhs_sinv_ptr
,
rhs_sinv_dtype
,
bias_ptr
,
bias_dtype
,
out_ptr
,
out_dtype
,
workspace_ptr
,
workspace_size
,
num_gemms
,
dim_list_ptr
,
scaling_mode
,
stream
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
GroupedGemmHandler
,
GroupedGemmFFI
,
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
GroupedGemmHandler
,
GroupedGemmFFI
,
FFI
::
Bind
()
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// lhs_flatten
.
RemainingArgs
()
// input list
.
Arg
<
Buffer_Type
>
()
// lhs_sinv_flatten
.
RemainingRets
()
// output list
.
Arg
<
Buffer_Type
>
()
// rhs_flatten
.
Arg
<
Buffer_Type
>
()
// rhs_sinv_flatten
.
Arg
<
Buffer_Type
>
()
// bias_flatten
.
Arg
<
Buffer_Type
>
()
// dim_list
.
Ret
<
Buffer_Type
>
()
// out_flatten
.
Ret
<
Buffer_Type
>
()
// workspace_flatten
.
Attr
<
int64_t
>
(
"num_gemms"
)
.
Attr
<
int64_t
>
(
"num_gemms"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
),
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"has_bias"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
}
// namespace jax
}
// namespace jax
...
...
transformer_engine/jax/csrc/extensions/misc.h
View file @
ab3e5a92
...
@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) {
...
@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) {
return
ret
;
return
ret
;
}
}
enum
class
Quantize
Axis
{
enum
class
Quantize
Layout
{
ROWWISE
,
ROWWISE
,
COLWISE
,
COLWISE
,
ROWWISE_COLWISE
,
ROWWISE_COLWISE
,
};
};
enum
class
JAXX_Scaling_Mode
:
int64_t
{
NO_SCALING
=
0
,
DELAYED_TENSOR_SCALING
=
1
,
MXFP8_1D_SCALING
=
2
,
};
static
NVTEScalingMode
get_nvte_scaling_mode
(
const
JAXX_Scaling_Mode
&
mode
)
{
switch
(
mode
)
{
case
JAXX_Scaling_Mode
::
NO_SCALING
:
return
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
;
break
;
case
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
:
return
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
;
break
;
case
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
:
return
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
;
break
;
default:
NVTE_ERROR
(
"Invalid Scaling Mode "
,
static_cast
<
int
>
(
mode
));
break
;
}
}
}
// namespace jax
}
// namespace jax
}
// namespace transformer_engine
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/normalization.cpp
View file @
ab3e5a92
...
@@ -14,7 +14,8 @@ namespace jax {
...
@@ -14,7 +14,8 @@ namespace jax {
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
w_dtype
,
DType
out_dtype
,
DType
w_dtype
,
DType
out_dtype
,
NVTE_Norm_Type
norm_type
,
int
scaling_mode
,
NVTE_Norm_Type
norm_type
,
JAXX_Scaling_Mode
scaling_mode
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
is_training
)
{
bool
is_training
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
...
@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
...
@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto
gamma_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
in_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
in_dtype
);
auto
rsigma_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
rsigma_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
output_tensor
.
set_rowwise_data
(
nullptr
,
out_dtype
,
input_shape
);
output_tensor
.
set_rowwise_data
(
nullptr
,
out_dtype
,
input_shape
);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if
(
is_training
&&
_
scaling_mode
==
NVTE_
MXFP8_1D_SCALING
)
{
if
(
is_training
&&
scaling_mode
==
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
)
{
int
temp
=
1
;
int
temp
=
1
;
output_tensor
.
set_columnwise_data
(
static_cast
<
void
*>
(
&
temp
),
out_dtype
,
input_shape
);
output_tensor
.
set_columnwise_data
(
static_cast
<
void
*>
(
&
temp
),
out_dtype
,
input_shape
);
}
}
...
@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
...
@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
}
else
{
}
else
{
NVTE_CHECK
(
scaling_mode
!=
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
NVTE_CHECK
(
scaling_mode
!=
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
epsilon
,
output_tensor
.
data
(),
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
epsilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
...
@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
...
@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
int
norm_type
,
bool
zero_centered_gamma
,
double
epsilon
,
int
norm_type
,
bool
zero_centered_gamma
,
double
epsilon
,
int64_t
sm_margin
,
int
scaling_mode
,
bool
is_2x
)
{
int64_t
sm_margin
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
.
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
.
element_type
());
...
@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
...
@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
*
workspace
=
wkspace_buf
->
untyped_data
();
auto
*
workspace
=
wkspace_buf
->
untyped_data
();
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
_norm_type
=
static_cast
<
NVTE_Norm_Type
>
(
norm_type
);
auto
_norm_type
=
static_cast
<
NVTE_Norm_Type
>
(
norm_type
);
auto
_is_2x
=
static_cast
<
bool
>
(
is_2x
);
auto
_is_2x
=
static_cast
<
bool
>
(
is_2x
);
...
@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
...
@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
_sm_margin
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
_sm_margin
;
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte
_scaling_mode
(
scaling_mode
)
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
input_shape
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
input_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
is_fp8_dtype
(
out_dtype
))
{
...
@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
...
@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf
->
dimensions
().
back
()});
scale_inv_buf
->
dimensions
().
back
()});
}
}
if
(
_
scaling_mode
==
NVTE_
DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
...
@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
...
@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
}
else
{
}
else
{
NVTE_CHECK
(
scaling_mode
!=
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
NVTE_CHECK
(
scaling_mode
!=
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
_epsilon
,
output_tensor
.
data
(),
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
_epsilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
...
@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
...
@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"epsilon"
)
.
Attr
<
double
>
(
"epsilon"
)
.
Attr
<
int64_t
>
(
"sm_margin"
)
.
Attr
<
int64_t
>
(
"sm_margin"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
bool
>
(
"is_2x"
),
.
Attr
<
bool
>
(
"is_2x"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
...
...
transformer_engine/jax/csrc/extensions/pybind.cpp
View file @
ab3e5a92
...
@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
...
@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.
value
(
"RMSNorm"
,
NVTE_Norm_Type
::
RMSNorm
)
.
value
(
"RMSNorm"
,
NVTE_Norm_Type
::
RMSNorm
)
.
export_values
();
.
export_values
();
pybind11
::
enum_
<
NVTE
ScalingMode
>
(
m
,
"
NVTE
_Scaling_Mode"
,
pybind11
::
module_local
())
pybind11
::
enum_
<
JAXX_
Scaling
_
Mode
>
(
m
,
"
JAXX
_Scaling_Mode"
,
pybind11
::
module_local
())
.
value
(
"N
VTE_DELAYED_TENSOR
_SCALING"
,
NVTE
ScalingMode
::
N
VTE_DELAYED_TENSOR
_SCALING
)
.
value
(
"N
O
_SCALING"
,
JAXX_
Scaling
_
Mode
::
N
O
_SCALING
)
.
value
(
"
NVTE_MXFP8_1D
_SCALING"
,
NVTE
ScalingMode
::
NVTE_MXFP8_1D
_SCALING
)
.
value
(
"
DELAYED_TENSOR
_SCALING"
,
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR
_SCALING
)
.
value
(
"
NVTE_INVALI
D_SCALING"
,
NVTE
ScalingMode
::
NVTE_
MXFP8_1D_SCALING
)
.
value
(
"
MXFP8_1
D_SCALING"
,
JAXX_
Scaling
_
Mode
::
MXFP8_1D_SCALING
)
.
export_values
();
.
export_values
();
pybind11
::
enum_
<
transformer_engine
::
jax
::
Quantize
Axis
>
(
m
,
"Quantize
Axis
"
,
pybind11
::
enum_
<
transformer_engine
::
jax
::
Quantize
Layout
>
(
m
,
"Quantize
Layout
"
,
pybind11
::
module_local
())
pybind11
::
module_local
())
.
value
(
"ROWWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
ROWWISE
)
.
value
(
"ROWWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
ROWWISE
)
.
value
(
"COLWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
COLWISE
)
.
value
(
"COLWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
COLWISE
)
.
value
(
"ROWWISE_COLWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
ROWWISE_COLWISE
)
.
value
(
"ROWWISE_COLWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
ROWWISE_COLWISE
)
.
export_values
();
.
export_values
();
}
}
...
...
transformer_engine/jax/csrc/extensions/quantization.cpp
View file @
ab3e5a92
...
@@ -13,7 +13,9 @@ namespace transformer_engine {
...
@@ -13,7 +13,9 @@ namespace transformer_engine {
namespace
jax
{
namespace
jax
{
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
)
{
DType
in_dtype
,
DType
out_dtype
,
JAXX_Scaling_Mode
scaling_mode
,
QuantizeLayout
q_layout
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
hidden_size
,
batch_size
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
hidden_size
,
batch_size
};
...
@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
...
@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
int
temp
=
0
;
int
temp
=
0
;
auto
input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
input_shape
,
in_dtype
);
auto
input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
output_shape
,
out_dtype
);
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_trans_shape
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
q_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
||
q_layout
==
QuantizeLayout
::
ROWWISE
)
{
output_tensor
.
set_rowwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
}
if
(
q_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
||
q_layout
==
QuantizeLayout
::
COLWISE
)
{
auto
&
tmp_shape
=
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
tmp_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_columnwise_scale_inv
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
}
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_amax
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
TensorWrapper
dummy_workspace
;
TensorWrapper
dummy_workspace
;
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
...
@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
...
@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type
DBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Error_Type
DBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
scale_inv_buf
,
Result_Type
trans
_scale_inv_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise
_scale_inv_buf
,
Result_Type
amax_
out_
buf
,
Result_Type
dbias_buf
,
Result_Type
amax_buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
Result_Type
workspace_buf
,
int64_t
scaling_mode
_enum
,
JAXX_Scaling_Mode
scaling_mode
,
int64_t
quantize_layout
_enum
,
int64_t
quantize_axis_enum
,
bool
is_dbia
s
)
{
bool
is_dbias
,
int64_t
flatten_axi
s
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
...
@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
...
@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
input
=
input_buf
.
untyped_data
();
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
const
quantize_layout
=
static_cast
<
QuantizeLayout
>
(
quantize_layout_enum
);
auto
const
quantize_axis
=
static_cast
<
QuantizeAxis
>
(
quantize_axis_enum
);
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
...
@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
...
@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void
*
workspace
=
workspace_buf
->
untyped_data
();
void
*
workspace
=
workspace_buf
->
untyped_data
();
auto
input_dims
=
input_buf
.
dimensions
();
auto
input_dims
=
input_buf
.
dimensions
();
int64_t
input_ndim
=
input_dims
.
size
();
if
(
flatten_axis
<
0
)
flatten_axis
+=
input_ndim
;
NVTE_CHECK
(
flatten_axis
<
input_ndim
&&
flatten_axis
>
0
,
"flatten_axis is out of bounds!"
);
auto
workspace_dims
=
workspace_buf
->
dimensions
();
auto
workspace_dims
=
workspace_buf
->
dimensions
();
auto
m
=
product
(
input_dims
,
0
,
input_dims
.
size
()
-
1
);
auto
m
=
product
(
input_dims
,
0
,
flatten_axis
);
auto
n
=
input_dims
.
back
(
);
auto
n
=
product
(
input_dims
,
flatten_axis
,
input_ndim
);
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
...
@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
...
@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
std
::
vector
<
size_t
>
workspace_shape
{
workspace_dims
.
begin
(),
workspace_dims
.
end
()};
std
::
vector
<
size_t
>
workspace_shape
{
workspace_dims
.
begin
(),
workspace_dims
.
end
()};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_
scaling_mode
(
scaling_mode
)
);
if
(
quantize_axis
==
QuantizeAxis
::
ROWWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
if
(
quantize_layout
==
QuantizeLayout
::
ROWWISE
||
quantize_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
)
{
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
is_fp8_dtype
(
out_dtype
))
{
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
NVTE_CHECK
(
amax_out
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
cudaMemsetAsync
(
amax_out
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax_out
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
}
}
if
(
quantize_axis
==
QuantizeAxis
::
COLWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
if
(
quantize_layout
==
QuantizeLayout
::
COLWISE
||
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
quantize_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
)
{
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
tmp_shape
);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
colwise_scale_inv_buf
=
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
trans_scale_inv_buf
;
?
scale_inv_buf
output_tensor
.
set_columnwise_scale_inv
(
:
colwise_scale_inv_buf
;
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
colwise_scale_inv_buf
->
dimensions
().
back
()});
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
}
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
...
@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
...
@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"q_axis"
)
.
Attr
<
int64_t
>
(
"q_layout"
)
.
Attr
<
bool
>
(
"is_dbias"
),
.
Attr
<
bool
>
(
"is_dbias"
)
.
Attr
<
int64_t
>
(
"flatten_axis"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
Error_Type
DequantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Error_Type
DequantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
...
...
transformer_engine/jax/dense.py
View file @
ab3e5a92
...
@@ -15,7 +15,11 @@ import jax
...
@@ -15,7 +15,11 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
from
.quantize
import
QuantizerSet
,
noop_quantizer_set
from
.quantize
import
(
QuantizerSet
,
noop_quantizer_set
,
with_sharding_constraint_by_logical_axes
,
)
def
dense
(
def
dense
(
...
@@ -23,6 +27,8 @@ def dense(
...
@@ -23,6 +27,8 @@ def dense(
kernel
:
jnp
.
ndarray
,
kernel
:
jnp
.
ndarray
,
bias
:
jnp
.
ndarray
=
None
,
bias
:
jnp
.
ndarray
=
None
,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
0
,)),
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
0
,)),
input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_axes
:
Tuple
[
str
,
...]
=
None
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
):
):
"""Perform dense layer transformation with optional quantization.
"""Perform dense layer transformation with optional quantization.
...
@@ -48,12 +54,12 @@ def dense(
...
@@ -48,12 +54,12 @@ def dense(
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
output
+=
jnp
.
reshape
(
bias
,
bias_new_shape
)
output
+=
jnp
.
reshape
(
bias
,
bias_new_shape
)
else
:
else
:
output
=
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
)
output
=
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
)
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
3
,))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
3
,
4
,
5
))
def
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
):
def
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
):
"""Internal implementation of dense layer transformation with custom VJP.
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
This function implements the core dense layer transformation logic with support
...
@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
...
@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix
kernel: Weight matrix
bias: Optional bias tensor
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Returns:
Transformed output tensor
Transformed output tensor
"""
"""
output
,
_
=
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
)
output
,
_
=
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
)
return
output
return
output
def
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
):
def
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
):
"""Forward pass rule for dense layer transformation.
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Returns:
Tuple of (output, context) for backward pass
Tuple of (output, context) for backward pass
"""
"""
x_contracting_dims
,
k_contracting_dims
=
contracting_dims
x_contracting_dims
,
k_contracting_dims
=
contracting_dims
casted_x
=
tex
.
quantize
(
x
,
quantizer_set
.
x
)
flatten_axis_x
=
-
len
(
x_contracting_dims
)
casted_kernel
=
tex
.
quantize
(
kernel
,
quantizer_set
.
kernel
)
flatten_axis_k
=
len
(
k_contracting_dims
)
-
len
(
kernel
.
shape
)
casted_x
=
tex
.
quantize
(
x
,
flatten_axis
=
flatten_axis_x
,
quantizer
=
quantizer_set
.
x
)
casted_x
=
with_sharding_constraint_by_logical_axes
(
casted_x
,
input_axes
)
casted_kernel
=
tex
.
quantize
(
kernel
,
flatten_axis
=
flatten_axis_k
,
quantizer
=
quantizer_set
.
kernel
)
casted_kernel
=
with_sharding_constraint_by_logical_axes
(
casted_kernel
,
kernel_axes
)
# GEMM NN
# GEMM NN
output
=
tex
.
gemm
(
output
=
tex
.
gemm
(
...
@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
...
@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel
.
get_colwise_tensor
(),
casted_kernel
.
get_colwise_tensor
(),
(
x_contracting_dims
,
k_contracting_dims
),
(
x_contracting_dims
,
k_contracting_dims
),
)
)
use_bias
=
bias
is
not
None
use_bias
=
bias
is
not
None
if
use_bias
:
if
use_bias
:
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
...
@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
...
@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel
.
shape
,
kernel
.
shape
,
use_bias
,
use_bias
,
quantizer_set
,
quantizer_set
,
flatten_axis_k
,
)
)
return
output
,
ctx
return
output
,
ctx
def
_dense_bwd_rule
(
contracting_dims
,
ctx
,
grad
):
# pylint: disable=unused-argument
def
_dense_bwd_rule
(
contracting_dims
,
input_axes
,
kernel_axes
,
ctx
,
grad
):
# pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Returns:
Tuple of gradients with respect to inputs
Tuple of gradients with respect to inputs
"""
"""
...
@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
...
@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape
,
kernel_shape
,
use_bias
,
use_bias
,
quantizer_set
,
quantizer_set
,
flatten_axis_k
,
)
=
ctx
)
=
ctx
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
quantizer
=
quantizer_set
.
dgrad
)
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
flatten_axis
=
flatten_axis_k
,
quantizer
=
quantizer_set
.
dgrad
)
# GEMM NT
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...
@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
...
@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel
,
rowwise_casted_kernel
,
(
g_constracting_dim
,
k_constracting_dim
),
(
g_constracting_dim
,
k_constracting_dim
),
)
)
dgrad
=
with_sharding_constraint_by_logical_axes
(
dgrad
,
input_axes
)
# GEMM TN
# GEMM TN
# x_non_contracting_dims
# x_non_contracting_dims
...
@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
...
@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad
=
tex
.
gemm
(
wgrad
=
tex
.
gemm
(
colwise_casted_x
,
casted_grad
.
get_colwise_tensor
(),
(
x_constracting_dim
,
g_constracting_dim
)
colwise_casted_x
,
casted_grad
.
get_colwise_tensor
(),
(
x_constracting_dim
,
g_constracting_dim
)
)
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
return
dgrad
,
wgrad
,
dbias
,
quantizer_set
return
dgrad
,
wgrad
,
dbias
,
quantizer_set
...
...
transformer_engine/jax/flax/module.py
View file @
ab3e5a92
...
@@ -13,7 +13,6 @@ import jax.numpy as jnp
...
@@ -13,7 +13,6 @@ import jax.numpy as jnp
from
flax
import
linen
as
nn
from
flax
import
linen
as
nn
from
flax.linen
import
partitioning
as
nn_partitioning
from
flax.linen
import
partitioning
as
nn_partitioning
from
jax
import
lax
from
jax
import
lax
from
jax
import
nn
as
jax_nn
from
jax
import
random
as
jax_random
from
jax
import
random
as
jax_random
from
jax.ad_checkpoint
import
checkpoint_name
from
jax.ad_checkpoint
import
checkpoint_name
...
@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp
...
@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp
from
..activation
import
activation
from
..activation
import
activation
from
..softmax
import
softmax
,
SoftmaxType
from
..softmax
import
softmax
,
SoftmaxType
from
..sharding
import
with_sharding_constraint_by_logical_axes
from
..sharding
import
with_sharding_constraint_by_logical_axes
from
..cpp_extensions
import
is_softmax_kernel_available
from
..cpp_extensions
import
(
is_softmax_kernel_available
,
jax_scaled_softmax
,
jax_scaled_masked_softmax
,
jax_scaled_upper_triang_masked_softmax
,
)
from
..quantize
import
QuantizerFactory
,
QuantizeConfig
,
QuantizeMeta
,
QuantizeMetaSet
,
ScalingMode
from
..quantize
import
QuantizerFactory
,
QuantizeConfig
,
QuantizeMeta
,
QuantizeMetaSet
,
ScalingMode
from
..sharding
import
get_non_contracting_logical_axes
PRNGKey
=
Any
PRNGKey
=
Any
Shape
=
Tuple
[
int
,
...]
Shape
=
Tuple
[
int
,
...]
...
@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype
=
inputs
.
dtype
input_dtype
=
inputs
.
dtype
logits
=
inputs
logits
=
inputs
if
self
.
softmax_type
is
not
SoftmaxType
.
SCALED
and
is_softmax_kernel_available
(
# use primitives
if
is_softmax_kernel_available
(
self
.
softmax_type
,
batch
,
heads
,
q_seqlen
,
k_seqlen
,
input_dtype
self
.
softmax_type
,
batch
,
heads
,
q_seqlen
,
k_seqlen
,
input_dtype
):
):
if
bias
is
not
None
:
if
bias
is
not
None
:
logits
=
logits
+
bias
.
astype
(
input_dtype
)
logits
=
logits
+
bias
.
astype
(
input_dtype
)
...
@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
mask_
=
None
mask_
=
None
outputs
=
softmax
(
logits
,
mask_
,
self
.
scale_factor
,
self
.
softmax_type
)
outputs
=
softmax
(
logits
,
mask_
,
self
.
scale_factor
,
self
.
softmax_type
)
# use default jax based implementation
else
:
else
:
attention_bias
=
None
if
mask
is
not
None
:
attention_bias
=
lax
.
select
(
mask
>
0
,
jnp
.
full
(
mask
.
shape
,
-
1e10
),
jnp
.
full
(
mask
.
shape
,
0.0
),
)
attention_bias
=
attention_bias
.
astype
(
input_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
attention_bias
=
_combine_biases
(
attention_bias
,
bias
)
logits
=
logits
+
bias
.
astype
(
input_dtype
)
if
attention_bias
is
not
None
:
logits
=
logits
+
attention_bias
.
astype
(
input_dtype
)
# For the case that
self.softmax
==
SoftmaxType.SCALED
_UPPER_TRIANG_MASKED
if
self
.
softmax
_type
is
SoftmaxType
.
SCALED
:
# and kernel is unavailable, then try on pure
scaled
softmax
custom calls.
outputs
=
jax_
scaled
_
softmax
(
logits
,
self
.
scale_factor
)
if
is_
softmax_
kernel_available
(
el
if
self
.
softmax_
type
is
SoftmaxType
.
SCALED_MASKED
:
SoftmaxType
.
SCALED
,
batch
,
heads
,
q_seqlen
,
k_seqlen
,
input_dtype
outputs
=
jax_scaled_masked_softmax
(
logits
,
mask
,
self
.
scale_factor
)
)
:
elif
self
.
softmax_type
is
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
outputs
=
softmax
(
logits
,
None
,
self
.
scale_factor
,
SoftmaxType
.
SCALED
)
outputs
=
jax_scaled_upper_triang_masked_
softmax
(
logits
,
self
.
scale_factor
)
else
:
else
:
outputs
=
jax_nn
.
softmax
(
logits
*
self
.
scale_factor
)
raise
ValueError
(
f
"Unsupported softmax type:
{
self
.
softmax_type
}
. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert
input_dtype
==
outputs
.
dtype
assert
input_dtype
==
outputs
.
dtype
return
outputs
return
outputs
...
@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
...
@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).
value
).
value
return
QuantizeMeta
(
scale
=
scale
,
amax_history
=
amax_history
)
return
QuantizeMeta
(
scale
=
scale
,
amax_history
=
amax_history
)
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
x_meta
=
generate_quantize_meta
(
"x"
)
x_meta
=
generate_quantize_meta
(
"x"
)
kernel_meta
=
generate_quantize_meta
(
"kernel"
)
kernel_meta
=
generate_quantize_meta
(
"kernel"
)
grad_meta
=
generate_quantize_meta
(
"grad"
)
grad_meta
=
generate_quantize_meta
(
"grad"
)
...
@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters
Optimization parameters
-----------------------
-----------------------
...
@@ -429,6 +429,7 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -429,6 +429,7 @@ class DenseGeneral(TransformerEngineBase):
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
dtype
:
DType
=
jnp
.
float32
dtype
:
DType
=
jnp
.
float32
transpose_batch_sequence
:
bool
=
False
transpose_batch_sequence
:
bool
=
False
input_axes
:
Tuple
[
str
,
...]
=
()
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
kernel_init
is
None
:
if
self
.
kernel_init
is
None
:
...
@@ -460,29 +461,35 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -460,29 +461,35 @@ class DenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
inputs
.
ndim
)
axis
=
_normalize_axes
(
axis
,
inputs
.
ndim
)
kernel_shape
=
tuple
(
inputs
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_shape
=
tuple
(
inputs
.
shape
[
ax
]
for
ax
in
axis
)
+
features
if
self
.
kernel_axes
:
assert
len
(
kernel_shape
)
==
len
(
self
.
kernel_axes
),
(
"Expected len(kernel_shape) to match len(kernel_axes),"
f
"got kernel_shape
{
kernel_shape
}
and kernel_axes
{
self
.
kernel_axes
}
"
)
kernel
=
nn_partitioning
.
param_with_axes
(
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
if
self
.
use_bias
:
if
self
.
use_bias
:
bias
=
nn_partitioning
.
param_with_axes
(
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
).
astype
(
input_dtype
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
else
:
else
:
bias
=
None
bias
=
None
quantizer_set
=
self
.
generate_quantizer_set
()
quantizer_set
=
self
.
generate_quantizer_set
()
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
y
=
dense
(
y
=
dense
(
inputs
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
inputs
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
...
@@ -491,20 +498,14 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -491,20 +498,14 @@ class DenseGeneral(TransformerEngineBase):
*
features
[:
-
1
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
lora_a_kernel_init_shape
=
(
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_shape
)
kernel_compute_shape
[
0
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_init_shape
)
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"lora_a_kernel"
,
"lora_a_kernel"
,
self
.
kernel_init
,
self
.
kernel_init
,
lora_a_kernel_
init_
shape
,
lora_a_kernel_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
lora_a_kernel_axes
,
axes
=
lora_a_kernel_axes
,
)
)
lora_a_kernel
=
jnp
.
reshape
(
lora_a_kernel
,
lora_a_kernel_shape
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
...
@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase):
y
+=
jnp
.
reshape
(
bias
,
bias_shape
)
y
+=
jnp
.
reshape
(
bias
,
bias_shape
)
assert
y
.
dtype
==
input_dtype
assert
y
.
dtype
==
input_dtype
y
=
y
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
return
y
return
y
...
@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
"""
assert
self
.
axis
==
-
1
,
"Only support axis = =-1 at this moment"
input_dtype
=
inputs
.
dtype
input_dtype
=
inputs
.
dtype
ln_output
=
None
ln_output
=
None
...
@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if
self
.
enable_layernorm
:
if
self
.
enable_layernorm
:
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
features
=
inputs
.
shape
[
-
1
]
features
=
inputs
.
shape
[
-
1
]
scale
,
ln_bias
=
_create_layernorm_parameters
(
scale
,
ln_bias
=
_create_layernorm_parameters
(
self
.
layernorm_type
,
self
.
layernorm_type
,
(
features
,),
(
features
,),
...
@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
kernel_shape
=
tuple
(
y
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),
)
+
features
kernel
=
nn_partitioning
.
param_with_axes
(
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
...
@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_input_axes
,
dot_input_axes
=
self
.
dot_input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
quantizer_set
=
quantizer_set
,
)
)
else
:
else
:
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_input_axes
)
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_input_axes
)
z
=
dense
(
y
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
)
z
=
dense
(
y
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
lora_a_kernel_shape
=
(
lora_a_kernel_shape
=
(
...
@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*
features
[:
-
1
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
lora_a_kernel_init_shape
=
(
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_shape
)
kernel_compute_shape
[
0
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_init_shape
)
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"lora_a_kernel"
,
"lora_a_kernel"
,
self
.
kernel_init
,
self
.
kernel_init
,
lora_a_kernel_
init_
shape
,
lora_a_kernel_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
lora_a_kernel_axes
,
axes
=
lora_a_kernel_axes
,
)
)
lora_a_kernel
=
jnp
.
reshape
(
lora_a_kernel
,
lora_a_kernel_shape
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
...
@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if
self
.
use_bias
:
if
self
.
use_bias
:
bias
=
nn_partitioning
.
param_with_axes
(
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
).
astype
(
input_dtype
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
bias_shape
=
(
1
,)
*
(
z
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
bias_shape
=
(
1
,)
*
(
z
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
...
@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z
=
z
/
self
.
depth_scaling
z
=
z
/
self
.
depth_scaling
assert
z
.
dtype
==
input_dtype
,
f
"output_dtype=
{
z
.
dtype
}
, input_dtype=
{
input_dtype
}
"
assert
z
.
dtype
==
input_dtype
,
f
"output_dtype=
{
z
.
dtype
}
, input_dtype=
{
input_dtype
}
"
z
=
z
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
#
z = z.reshape(*inputs.shape[: self.axis], *features)
return
z
,
ln_output
# dense_output, layer_norm_output
return
z
,
ln_output
# dense_output, layer_norm_output
...
@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
"""
assert
self
.
axis
==
-
1
,
"Only support axis == -1 at this moment"
ffn1_quantizer_set
=
self
.
generate_quantizer_set
(
"_0"
)
ffn1_quantizer_set
=
self
.
generate_quantizer_set
(
"_0"
)
ffn2_quantizer_set
=
self
.
generate_quantizer_set
(
"_1"
)
ffn2_quantizer_set
=
self
.
generate_quantizer_set
(
"_1"
)
...
@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase):
)
)
# LayerNorm
# LayerNorm
if
self
.
enable_layernorm
:
if
self
.
enable_layernorm
:
assert
self
.
axis
==
-
1
# Only support axis == -1 at this moment
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
features
=
inputs
.
shape
[
-
1
]
features
=
inputs
.
shape
[
-
1
]
...
@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations
=
len
(
normalized_acts
)
num_activations
=
len
(
normalized_acts
)
axis
=
_canonicalize_tuple
(
self
.
axis
)
axis
=
_canonicalize_tuple
(
self
.
axis
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
kernel_1_each_shape
=
(
np
.
prod
([
y
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1_each_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1
=
nn_partitioning
.
param_with_axes
(
kernel_1
=
nn_partitioning
.
param_with_axes
(
"wi_kernel"
,
"wi_kernel"
,
kernel_1_init
,
kernel_1_init
,
...
@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase):
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
kernel_axes_1
,
axes
=
self
.
kernel_axes_1
,
)
)
kernel_1_compute_shape
=
(
reduce
(
operator
.
mul
,
[
y
.
shape
[
ax
]
for
ax
in
axis
],
1
),
num_activations
*
self
.
intermediate_dim
,
)
kernel_1
=
jnp
.
reshape
(
kernel_1
,
kernel_1_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_1
=
kernel_1
.
astype
(
input_dtype
)
kernel_1
=
kernel_1
.
astype
(
input_dtype
)
hidden_size
=
inputs
.
shape
[
-
1
]
hidden_size
=
inputs
.
shape
[
-
1
]
hidden_size_tuple
=
_canonicalize_tuple
(
hidden_size
)
hidden_size_tuple
=
_canonicalize_tuple
(
hidden_size
)
kernel_2_shape
=
(
self
.
intermediate_dim
,)
+
hidden_size_tuple
kernel_2_shape
=
(
self
.
intermediate_dim
,)
+
hidden_size_tuple
...
@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase):
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
kernel_axes_2
,
axes
=
self
.
kernel_axes_2
,
)
)
kernel_2_compute_shape
=
(
self
.
intermediate_dim
,
reduce
(
operator
.
mul
,
hidden_size_tuple
,
1
),
)
kernel_2
=
jnp
.
reshape
(
kernel_2
,
kernel_2_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_2
=
kernel_2
.
astype
(
input_dtype
)
kernel_2
=
kernel_2
.
astype
(
input_dtype
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
if
self
.
use_bias
:
if
self
.
use_bias
:
bias_1_shape
=
num_activations
*
self
.
intermediate_dim
bias_1_shape
=
(
num_activations
,
self
.
intermediate_dim
)
bias_1
=
nn_partitioning
.
param_with_axes
(
bias_1
=
nn_partitioning
.
param_with_axes
(
"wi_bias"
,
"wi_bias"
,
self
.
bias_init
,
self
.
bias_init
,
bias_1_shape
,
bias_1_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
bias_axes_1
,
axes
=
self
.
bias_axes_1
,
)
).
astype
(
input_dtype
)
bias_1
=
bias_1
.
reshape
(
kernel_1_compute_shape
[
-
1
]).
astype
(
input_dtype
)
bias_2_shape
=
(
hidden_size
,)
bias_2_shape
=
(
hidden_size
,)
bias_2
=
nn_partitioning
.
param_with_axes
(
bias_2
=
nn_partitioning
.
param_with_axes
(
...
@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape
,
bias_2_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
bias_axes_2
,
axes
=
self
.
bias_axes_2
,
)
).
astype
(
input_dtype
)
bias_2
=
bias_2
.
reshape
(
kernel_2_compute_shape
[
-
1
]).
astype
(
input_dtype
)
else
:
else
:
bias_1
=
None
bias_1
=
None
bias_2
=
None
bias_2
=
None
...
@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name
=
"ffn2"
ffn2_ckpt_name
=
"ffn2"
if
use_fused_layernorm_mlp
:
if
use_fused_layernorm_mlp
:
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
out
=
layernorm_mlp
(
out
=
layernorm_mlp
(
y
,
y
,
scale
,
scale
,
...
@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes
=
self
.
layernorm_input_axes
,
norm_input_axes
=
self
.
layernorm_input_axes
,
dot_1_input_axes
=
self
.
dot_1_input_axes
,
dot_1_input_axes
=
self
.
dot_1_input_axes
,
dot_2_input_axes
=
self
.
dot_2_input_axes
,
dot_2_input_axes
=
self
.
dot_2_input_axes
,
kernel_1_axes
=
self
.
kernel_axes_1
,
kernel_2_axes
=
self
.
kernel_axes_2
,
ffn1_ckpt_name
=
ffn1_ckpt_name
,
ffn1_ckpt_name
=
ffn1_ckpt_name
,
ffn2_ckpt_name
=
ffn2_ckpt_name
,
ffn2_ckpt_name
=
ffn2_ckpt_name
,
activation_type
=
normalized_acts
,
activation_type
=
normalized_acts
,
...
@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_1_input_axes
,
dot_input_axes
=
self
.
dot_1_input_axes
,
kernel_axes
=
self
.
kernel_axes_1
,
quantizer_set
=
ffn1_quantizer_set
,
quantizer_set
=
ffn1_quantizer_set
,
)
)
else
:
else
:
...
@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase):
y
,
y
,
kernel_1
,
kernel_1
,
contracting_dims
=
(
axis
,
contract_ind
),
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_1_input_axes
,
kernel_axes
=
self
.
kernel_axes_1
,
quantizer_set
=
ffn1_quantizer_set
,
quantizer_set
=
ffn1_quantizer_set
,
)
)
dot_1_output_axes
=
(
*
get_non_contracting_logical_axes
(
y
.
ndim
,
self
.
dot_1_input_axes
,
axis
),
*
get_non_contracting_logical_axes
(
kernel_1
.
ndim
,
self
.
kernel_axes_1
,
contract_ind
),
)
x
=
with_sharding_constraint_by_logical_axes
(
x
,
dot_1_output_axes
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
wi_lora_a_kernel_shape
=
(
wi_lora_a_kernel_each_shape
=
(
kernel_1_compute_shape
[
0
],
kernel_1_each_shape
[:
len
(
axis
)],
num_activations
,
self
.
low_rank_adaptation_dim
,
)
wi_lora_a_kernel_init_shape
=
(
kernel_1_each_shape
[
0
],
num_activations
,
self
.
low_rank_adaptation_dim
,
)
wi_lora_a_kernel_init_each_shape
=
(
kernel_1_each_shape
[
0
],
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
wi_lora_a_kernel_axes
=
(
None
,)
*
len
(
wi_lora_a_kernel_
init
_shape
)
wi_lora_a_kernel_axes
=
(
None
,)
*
len
(
wi_lora_a_kernel_
each
_shape
+
1
)
wi_lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
wi_lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"wi_lora_a_kernel"
,
"wi_lora_a_kernel"
,
kernel_1_init
,
kernel_1_init
,
num_activations
,
num_activations
,
-
1
,
-
2
,
wi_lora_a_kernel_
init_
each_shape
,
wi_lora_a_kernel_each_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
wi_lora_a_kernel_axes
,
axes
=
wi_lora_a_kernel_axes
,
)
)
wi_lora_a_kernel
=
jnp
.
reshape
(
wi_lora_a_kernel
,
wi_lora_a_kernel_shape
)
wi_lora_a_kernel
=
wi_lora_a_kernel
.
astype
(
input_dtype
)
wi_lora_a_kernel
=
wi_lora_a_kernel
.
astype
(
input_dtype
)
wi_lora_b_kernel_shape
=
(
wi_lora_b_kernel_shape
=
(
...
@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase):
x
+=
_apply_low_rank_adaptation
(
x
+=
_apply_low_rank_adaptation
(
y
,
y
,
axis
,
axis
,
num_activations
*
self
.
intermediate_dim
,
(
num_activations
,
self
.
intermediate_dim
)
,
wi_lora_a_kernel
,
wi_lora_a_kernel
,
wi_lora_b_kernel
,
wi_lora_b_kernel
,
self
.
low_rank_adaptation_alpha
,
self
.
low_rank_adaptation_alpha
,
...
@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase):
z
=
activation
(
x
,
normalized_acts
)
z
=
activation
(
x
,
normalized_acts
)
else
:
else
:
activations
=
[]
activations
=
[]
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
1
)
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
2
)
for
idx
,
act_fn
in
enumerate
(
normalized_acts
):
for
idx
,
act_fn
in
enumerate
(
normalized_acts
):
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
activations
.
append
(
x_i
)
activations
.
append
(
x_i
)
z
=
reduce
(
operator
.
mul
,
activations
)
z
=
reduce
(
operator
.
mul
,
activations
)
z
=
jnp
.
squeeze
(
z
,
axis
=-
2
)
z
=
z
.
astype
(
input_dtype
)
z
=
z
.
astype
(
input_dtype
)
z
=
nn
.
Dropout
(
z
=
nn
.
Dropout
(
...
@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2
# DenseGeneral 2
out
=
dense
(
out
=
dense
(
z
,
kernel_2
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
ffn2_quantizer_set
z
,
kernel_2
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_2_input_axes
,
kernel_axes
=
self
.
kernel_axes_2
,
quantizer_set
=
ffn2_quantizer_set
,
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
...
...
transformer_engine/jax/flax/transformer.py
View file @
ab3e5a92
...
@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
...
@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if
mask
is
not
None
:
if
mask
is
not
None
:
mask
=
apply_swa_mask
(
mask
)
mask
=
apply_swa_mask
(
mask
)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if
attn_mask_type
in
[
AttnMaskType
.
CAUSAL_MASK
,
AttnMaskType
.
PADDING_CAUSAL_MASK
]:
if
mask
is
not
None
:
return
SoftmaxType
.
SCALED_MASKED
,
mask
if
attn_mask_type
is
AttnMaskType
.
CAUSAL_MASK
:
return
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
mask
return
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
mask
if
attn_mask_type
in
[
AttnMaskType
.
NO_MASK
,
AttnMaskType
.
PADDING_MASK
]:
if
attn_mask_type
is
AttnMaskType
.
NO_MASK
:
if
mask
is
not
None
:
return
SoftmaxType
.
SCALED_MASKED
,
mask
return
SoftmaxType
.
SCALED
,
mask
return
SoftmaxType
.
SCALED
,
mask
raise
ValueError
(
raise
ValueError
(
f
"Unsupported
{
attn_mask_type
=
}
, supported attn_mask_type="
f
"Unsupported
{
attn_mask_type
=
}
, supported attn_mask_type="
...
@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
attn_mask_type mask/sequence_descriptor SWA softmax type
--------------------------------------------------------------------------------------------
no_mask None None SCALED
causal None None SCALED_UPPER_TRIANG_MASKED
causal None Yes SCALED_MASKED
padding Required Yes/No SCALED_MASKED
padding_causal Required Yes/No SCALED_MASKED
attn_bias_type: Optional[str], default = None
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
...
...
transformer_engine/jax/layernorm_dense.py
View file @
ab3e5a92
...
@@ -33,10 +33,9 @@ def layernorm_dense(
...
@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type
:
str
=
"layernorm"
,
norm_type
:
str
=
"layernorm"
,
zero_centered_gamma
:
bool
=
False
,
zero_centered_gamma
:
bool
=
False
,
epsilon
:
float
=
1e-6
,
epsilon
:
float
=
1e-6
,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes
:
Tuple
[
str
,
...]
=
None
,
layernorm_input_axes
:
Tuple
[
str
,
...]
=
None
,
# The logic axes of sharding constraint to the dot input.
dot_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_axes
:
Tuple
[
str
,
...]
=
None
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
)
->
jnp
.
ndarray
:
)
->
jnp
.
ndarray
:
"""Apply layer normalization followed by dense layer transformation.
"""Apply layer normalization followed by dense layer transformation.
...
@@ -56,6 +55,7 @@ def layernorm_dense(
...
@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types
quantizer_set: Set of quantizers for different tensor types
Returns:
Returns:
...
@@ -78,6 +78,7 @@ def layernorm_dense(
...
@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon
,
epsilon
,
layernorm_input_axes
,
layernorm_input_axes
,
dot_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
quantizer_set
,
)
)
return
output
return
output
...
@@ -91,6 +92,7 @@ def layernorm_dense(
...
@@ -91,6 +92,7 @@ def layernorm_dense(
7
,
7
,
8
,
8
,
9
,
9
,
10
,
),
),
)
)
def
_layernorm_dense
(
def
_layernorm_dense
(
...
@@ -104,6 +106,7 @@ def _layernorm_dense(
...
@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon
:
float
,
epsilon
:
float
,
layernorm_input_axes
:
Tuple
[
str
,
...],
layernorm_input_axes
:
Tuple
[
str
,
...],
dot_input_axes
:
Tuple
[
str
,
...],
dot_input_axes
:
Tuple
[
str
,
...],
kernel_axes
:
Tuple
[
str
,
...],
quantizer_set
,
quantizer_set
,
):
):
"""Internal implementation of layernorm_dense with custom VJP.
"""Internal implementation of layernorm_dense with custom VJP.
...
@@ -139,6 +142,7 @@ def _layernorm_dense(
...
@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon
,
epsilon
,
layernorm_input_axes
,
layernorm_input_axes
,
dot_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
quantizer_set
,
)
)
return
output
return
output
...
@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
...
@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon
,
epsilon
,
layernorm_input_axes
,
layernorm_input_axes
,
dot_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
quantizer_set
,
):
):
"""Forward pass rule for layernorm_dense.
"""Forward pass rule for layernorm_dense.
...
@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
...
@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
k_contracting_dims
=
(
0
,)
k_contracting_dims
=
(
0
,)
assert
x
.
shape
[
-
1
]
==
kernel
.
shape
[
0
]
assert
x
.
shape
[
-
1
]
==
kernel
.
shape
[
0
]
assert
len
(
kernel
.
shape
)
==
2
# Otherwise need to merge dims in quantize
x
=
with_sharding_constraint_by_logical_axes
(
x
,
layernorm_input_axes
)
x
=
with_sharding_constraint_by_logical_axes
(
x
,
layernorm_input_axes
)
...
@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
...
@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type
,
norm_type
,
quantizer_set
.
x
,
quantizer_set
.
x
,
)
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_input_axes
)
# Kernel in (hidden_in, hidden_out...)
# Kernel in (hidden_in, hidden_out...)
casted_kernel
=
tex
.
quantize
(
kernel
,
quantizer_set
.
kernel
)
flatten_axis
=
1
-
len
(
kernel
.
shape
)
casted_kernel
=
tex
.
quantize
(
kernel
,
flatten_axis
=
flatten_axis
,
quantizer
=
quantizer_set
.
kernel
)
casted_
ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_
ln_out
,
dot_input
_axes
)
casted_
kernel
=
with_sharding_constraint_by_logical_axes
(
casted_
kernel
,
kernel
_axes
)
# NN GEMM
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
# (batch..., hidden_in) x (hidden_in, hidden_out...)
...
@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
...
@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims
,
k_contracting_dims
,
use_bias
,
use_bias
,
quantizer_set
,
quantizer_set
,
flatten_axis
,
)
)
return
output
,
ctx
return
output
,
ctx
...
@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
...
@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon
,
epsilon
,
layernorm_input_axes
,
layernorm_input_axes
,
dot_input_axes
,
# pylint: disable=unused-argument
dot_input_axes
,
# pylint: disable=unused-argument
kernel_axes
,
ctx
,
ctx
,
grad
,
grad
,
):
):
...
@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
...
@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd
,
k_contracting_dims_in_fwd
,
use_bias
,
use_bias
,
quantizer_set
,
quantizer_set
,
flatten_axis
,
)
=
ctx
)
=
ctx
grad
=
with_sharding_constraint_by_logical_axes
(
grad
,
dot_input_axes
)
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
flatten_axis
=
flatten_axis
,
quantizer
=
quantizer_set
.
dgrad
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
quantizer
=
quantizer_set
.
dgrad
)
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim
=
tuple
(
g_constracting_dim
=
tuple
(
...
@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
...
@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(
x_constracting_dim
,
g_constracting_dim
),
(
x_constracting_dim
,
g_constracting_dim
),
)
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dgrad
,
dgrad
,
x
,
x
,
...
...
transformer_engine/jax/layernorm_mlp.py
View file @
ab3e5a92
...
@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
...
@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
from
.layernorm
import
canonicalize_norm_type
from
.layernorm
import
canonicalize_norm_type
from
.quantize
import
with_sharding_constraint_by_logical_axes
,
QuantizerSet
,
noop_quantizer_set
from
.quantize
import
with_sharding_constraint_by_logical_axes
,
QuantizerSet
,
noop_quantizer_set
from
.sharding
import
get_non_contracting_logical_axes
def
layernorm_mlp
(
def
layernorm_mlp
(
...
@@ -37,6 +38,8 @@ def layernorm_mlp(
...
@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes
:
Tuple
[
str
,
...]
=
None
,
norm_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_1_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_1_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_2_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_2_input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_1_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_2_axes
:
Tuple
[
str
,
...]
=
None
,
ffn1_ckpt_name
:
str
=
"ffn1"
,
ffn1_ckpt_name
:
str
=
"ffn1"
,
ffn2_ckpt_name
:
str
=
"ffn2"
,
ffn2_ckpt_name
:
str
=
"ffn2"
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
...
@@ -66,6 +69,8 @@ def layernorm_mlp(
...
@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
activation_type: Activation function(s) to apply after the first dense layer transformation
...
@@ -109,6 +114,8 @@ def layernorm_mlp(
...
@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
...
@@ -117,7 +124,7 @@ def layernorm_mlp(
...
@@ -117,7 +124,7 @@ def layernorm_mlp(
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
))
def
_layernorm_mlp
(
def
_layernorm_mlp
(
x
:
jnp
.
ndarray
,
x
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
...
@@ -132,6 +139,8 @@ def _layernorm_mlp(
...
@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes
:
Tuple
[
str
,
...],
norm_input_axes
:
Tuple
[
str
,
...],
dot_1_input_axes
:
Tuple
[
str
,
...],
dot_1_input_axes
:
Tuple
[
str
,
...],
dot_2_input_axes
:
Tuple
[
str
,
...],
dot_2_input_axes
:
Tuple
[
str
,
...],
kernel_1_axes
:
Tuple
[
str
,
...],
kernel_2_axes
:
Tuple
[
str
,
...],
ffn1_ckpt_name
:
str
,
ffn1_ckpt_name
:
str
,
ffn2_ckpt_name
:
str
,
ffn2_ckpt_name
:
str
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]],
activation_type
:
Sequence
[
Union
[
str
,
Callable
]],
...
@@ -179,6 +188,8 @@ def _layernorm_mlp(
...
@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
...
@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
...
@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
...
@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
...
@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns:
Returns:
Tuple of (output, context) for automatic differentiation
Tuple of (output, context) for automatic differentiation
"""
"""
del
kernel_2_axes
ffn1_quantizer_set
,
ffn2_quantizer_set
=
quantizer_sets
ffn1_quantizer_set
,
ffn2_quantizer_set
=
quantizer_sets
# x should be in shape of (batch..., hidden)
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len
*
intermediate)
# Kernel_1 should be in shape of (hidden_in, activation_len
,
intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert
len
(
kernel_1
.
shape
)
==
2
assert
len
(
kernel_1
.
shape
)
==
3
assert
len
(
kernel_2
.
shape
)
==
2
assert
len
(
kernel_2
.
shape
)
==
2
assert
kernel_1
.
shape
[
1
]
==
kernel_2
.
shape
[
0
]
*
len
(
activation_type
)
assert
kernel_1
.
shape
[
-
2
]
==
len
(
activation_type
)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
k_contracting_dims
=
(
0
,)
k_contracting_dims
=
(
0
,)
assert
x
.
shape
[
x_contracting_dims
[
0
]]
==
kernel_1
.
shape
[
k_contracting_dims
[
0
]]
assert
x
.
shape
[
x_contracting_dims
[
0
]]
==
kernel_1
.
shape
[
k_contracting_dims
[
0
]]
assert
kernel_1
.
shape
[
1
]
==
len
(
activation_type
)
*
kernel_2
.
shape
[
0
]
use_bias_1
=
bias_1
is
not
None
use_bias_1
=
bias_1
is
not
None
use_bias_2
=
bias_1
is
not
None
use_bias_2
=
bias_1
is
not
None
...
@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
...
@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type
,
norm_type
,
quantizer
=
ffn1_quantizer_set
.
x
,
quantizer
=
ffn1_quantizer_set
.
x
,
)
)
casted_kernel_1
=
tex
.
quantize
(
kernel_1
,
quantizer
=
ffn1_quantizer_set
.
kernel
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_1_input_axes
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_1_input_axes
)
casted_kernel_1
=
tex
.
quantize
(
kernel_1
,
flatten_axis
=-
2
,
quantizer
=
ffn1_quantizer_set
.
kernel
)
# NN GEMM
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output
=
tex
.
gemm
(
dot_1_output
=
tex
.
gemm
(
...
@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
...
@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1
.
get_colwise_tensor
(),
casted_kernel_1
.
get_colwise_tensor
(),
(
x_contracting_dims
,
k_contracting_dims
),
(
x_contracting_dims
,
k_contracting_dims
),
)
)
dot_1_output_axes
=
(
*
get_non_contracting_logical_axes
(
x
.
ndim
,
dot_1_input_axes
,
x_contracting_dims
),
*
get_non_contracting_logical_axes
(
kernel_1
.
ndim
,
kernel_1_axes
,
k_contracting_dims
),
)
dot_1_output
=
with_sharding_constraint_by_logical_axes
(
dot_1_output
,
dot_1_output_axes
)
if
use_bias_1
:
if
use_bias_1
:
bias_1_shape
=
bias_1
.
shape
bias_1_shape
=
bias_1
.
shape
bias_1_new_shape
=
(
1
,)
*
(
dot_1_output
.
ndim
-
bias_1
.
ndim
)
+
bias_1_shape
bias_1_new_shape
=
(
1
,)
*
(
dot_1_output
.
ndim
-
bias_1
.
ndim
)
+
bias_1_shape
...
@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
...
@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(
x_contracting_dims
,
k_contracting_dims
),
(
x_contracting_dims
,
k_contracting_dims
),
)
)
dot_2_output_axes
=
(
*
get_non_contracting_logical_axes
(
x
.
ndim
,
dot_2_input_axes
,
x_contracting_dims
),
*
get_non_contracting_logical_axes
(
kernel_2
.
ndim
,
None
,
k_contracting_dims
),
)
dot_2_output
=
with_sharding_constraint_by_logical_axes
(
dot_2_output
,
dot_2_output_axes
)
if
use_bias_2
:
if
use_bias_2
:
bias_2_shape
=
bias_2
.
shape
bias_2_shape
=
bias_2
.
shape
bias_2_new_shape
=
(
1
,)
*
(
dot_2_output
.
ndim
-
bias_2
.
ndim
)
+
bias_2_shape
bias_2_new_shape
=
(
1
,)
*
(
dot_2_output
.
ndim
-
bias_2
.
ndim
)
+
bias_2_shape
...
@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
...
@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
# pylint: disable=unused-argument
kernel_1_axes
,
ffn2_ckpt_name
,
# pylint: disable=unused-argument
kernel_2_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
ctx
,
ctx
,
grad
,
grad
,
...
@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
...
@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Returns:
Tuple of gradients for all input parameters
Tuple of gradients for all input parameters
"""
"""
del
norm_input_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
(
(
x
,
x
,
mu
,
mu
,
...
@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
...
@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
)
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_con
s
tracting_dim_2
=
tuple
(
g_contracting_dim
s
_2
=
tuple
(
range
(
grad
.
ndim
-
len
(
kernel_2_shape
)
+
len
(
k_contracting_dims_in_fwd
),
grad
.
ndim
)
range
(
grad
.
ndim
-
len
(
kernel_2_shape
)
+
len
(
k_contracting_dims_in_fwd
),
grad
.
ndim
)
)
)
# k_non_contracting_dims
# k_non_contracting_dims
k_con
s
tracting_dim_2
=
tuple
(
k_contracting_dim
s
_2
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_2_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
dim
for
dim
in
range
(
len
(
kernel_2_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
)
)
...
@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
...
@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2
=
tex
.
gemm
(
dgrad_2
=
tex
.
gemm
(
casted_grad
.
get_rowwise_tensor
(),
casted_grad
.
get_rowwise_tensor
(),
rowwise_casted_kernel_2
,
rowwise_casted_kernel_2
,
(
g_con
s
tracting_dim_2
,
k_con
s
tracting_dim_2
),
(
g_contracting_dim
s
_2
,
k_contracting_dim
s
_2
),
)
)
dgrad_2
=
with_sharding_constraint_by_logical_axes
(
dgrad_2
,
dot_2_input_axes
)
dgrad_2
=
with_sharding_constraint_by_logical_axes
(
dgrad_2
,
dot_2_input_axes
)
x_con
s
tracting_dim
=
g_con
s
tracting_dim
=
tuple
(
x_contracting_dim
s
=
g_contracting_dim
s
=
tuple
(
range
(
0
,
len
(
x
.
shape
)
-
len
(
x_contracting_dims_in_fwd
))
range
(
0
,
len
(
x
.
shape
)
-
len
(
x_contracting_dims_in_fwd
))
)
)
...
@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
...
@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2
=
tex
.
gemm
(
wgrad_2
=
tex
.
gemm
(
colwise_casted_act_out
,
colwise_casted_act_out
,
casted_grad
.
get_colwise_tensor
(),
casted_grad
.
get_colwise_tensor
(),
(
x_con
s
tracting_dim
,
g_con
s
tracting_dim
),
(
x_contracting_dim
s
,
g_contracting_dim
s
),
)
)
wgrad_2
=
with_sharding_constraint_by_logical_axes
(
wgrad_2
,
kernel_2_axes
)
casted_dact_out
,
dbias_1
=
tex
.
quantize_dact_dbias
(
casted_dact_out
,
dbias_1
=
tex
.
quantize_dact_dbias
(
dgrad_2
,
dgrad_2
,
...
@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
...
@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
)
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1
=
tuple
(
dact_out_ndim
=
casted_dact_out
.
get_rowwise_tensor
().
data
.
ndim
range
(
dgrad_2
.
ndim
-
len
(
kernel_1_shape
)
+
len
(
k_contracting_dims_in_fwd
),
dgrad_2
.
ndim
)
g_contracting_dims_1
=
tuple
(
range
(
dact_out_ndim
-
len
(
kernel_1_shape
)
+
len
(
k_contracting_dims_in_fwd
),
dact_out_ndim
)
)
)
# k_non_contracting_dims
# k_non_contracting_dims
k_con
s
tracting_dim_1
=
tuple
(
k_contracting_dim
s
_1
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_1_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
dim
for
dim
in
range
(
len
(
kernel_1_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
)
)
...
@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
...
@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1
=
tex
.
gemm
(
dgrad_1
=
tex
.
gemm
(
casted_dact_out
.
get_rowwise_tensor
(),
casted_dact_out
.
get_rowwise_tensor
(),
rowwise_casted_kernel_1
,
rowwise_casted_kernel_1
,
(
g_con
s
tracting_dim_1
,
k_con
s
tracting_dim_1
),
(
g_contracting_dim
s
_1
,
k_contracting_dim
s
_1
),
)
)
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
norm
_input_axes
)
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
dot_1
_input_axes
)
# TN GEMM
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
# (hidden, batch...) x (hidden, batch...)
wgrad_1
=
tex
.
gemm
(
wgrad_1
=
tex
.
gemm
(
colwise_casted_ln_out
,
colwise_casted_ln_out
,
casted_dact_out
.
get_colwise_tensor
(),
casted_dact_out
.
get_colwise_tensor
(),
(
x_con
s
tracting_dim
,
g_con
s
tracting_dim
),
(
x_contracting_dim
s
,
g_contracting_dim
s
),
)
)
wgrad_1
=
with_sharding_constraint_by_logical_axes
(
wgrad_1
,
kernel_1_axes
)
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dgrad_1
,
dgrad_1
,
x
,
x
,
...
...
transformer_engine/jax/praxis/__init__.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from
.module
import
FusedSoftmax
,
LayerNorm
from
.module
import
LayerNormLinear
,
LayerNormMLP
,
Linear
,
TransformerEngineBaseLayer
from
.transformer
import
DotProductAttention
,
MultiHeadAttention
from
.transformer
import
RelativePositionBiases
,
TransformerLayer
from
..flax.transformer
import
TransformerLayerType
transformer_engine/jax/praxis/module.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from
dataclasses
import
field
from
functools
import
partial
from
typing
import
Callable
,
Iterable
,
Sequence
,
Tuple
,
Union
from
praxis
import
pax_fiddle
from
praxis.base_layer
import
init_var
from
praxis.base_layer
import
BaseLayer
,
WeightInit
,
WeightHParams
,
WeightHParamsCollection
from
praxis.layers
import
flax_adapter
from
praxis.pytypes
import
JTensor
from
..fp8
import
FP8Helper
from
..flax.module
import
DenseGeneral
,
LayerNormDenseGeneral
from
..flax.module
import
LayerNorm
as
flax_LayerNorm
from
..flax.module
import
LayerNormMLP
as
flax_LayerNormMLP
from
..flax.module
import
Softmax
from
..softmax
import
SoftmaxType
def
_generate_ln_scale_init
(
scale_init
):
if
scale_init
is
not
None
:
return
TransformerEngineBaseLayer
.
generate_params_init
(
"scale"
,
scale_init
)
return
scale_init
class
TransformerEngineBaseLayer
(
BaseLayer
):
"""TransformerEngineBaseLayer"""
logical_axes_rules
:
Tuple
[
Tuple
,
...]
=
None
@
staticmethod
def
generate_params_init
(
name
:
str
,
initializer
:
WeightInit
):
"""generate_params_init"""
def
kernel_init
(
key
,
shape
,
dtype
):
wp
=
WeightHParams
(
shape
=
shape
,
init
=
initializer
,
dtype
=
dtype
)
return
init_var
(
wp
,
key
,
name
)
return
kernel_init
def
create_layer
(
self
,
name
,
flax_module_cls
):
"""create_layer"""
fp8_collection_map
=
{
FP8Helper
.
FP8_COLLECTION_NAME
:
[
WeightHParamsCollection
.
SKIP_LP_REGULARIZATION
,
WeightHParamsCollection
.
OVERWRITE_WITH_GRADIENT
,
WeightHParamsCollection
.
DISALLOW_BFLOAT16_CONVERSION
,
]
}
flax_module_p
=
pax_fiddle
.
Config
(
flax_adapter
.
FlaxModuleAdapter
,
module_factory_method
=
flax_module_cls
,
logical_axes_rules
=
self
.
logical_axes_rules
,
var_collection_map
=
fp8_collection_map
,
ici_mesh_shape
=
self
.
ici_mesh_shape
,
dcn_mesh_shape
=
self
.
dcn_mesh_shape
,
mesh_axis_names
=
self
.
mesh_axis_names
,
)
self
.
create_child
(
name
,
flax_module_p
.
clone
())
class
LayerNorm
(
TransformerEngineBaseLayer
):
"""LayerNorm"""
epsilon
:
float
=
1e-6
layernorm_type
:
str
=
"layernorm"
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_cls
=
partial
(
flax_LayerNorm
,
epsilon
=
self
.
epsilon
,
layernorm_type
=
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"layer_norm"
,
ln_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
layer_norm
(
x
)
class
FusedSoftmax
(
TransformerEngineBaseLayer
):
"""FusedSoftmax"""
scale_factor
:
float
=
1.0
softmax_type
:
SoftmaxType
=
SoftmaxType
.
SCALED
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
fused_softmax_cls
=
partial
(
Softmax
,
scale_factor
=
self
.
scale_factor
,
softmax_type
=
self
.
softmax_type
)
self
.
create_layer
(
"fused_softmax"
,
fused_softmax_cls
)
def
__call__
(
self
,
x
:
JTensor
,
mask
:
JTensor
=
None
,
bias
:
JTensor
=
None
)
->
JTensor
:
"""__call__"""
return
self
.
fused_softmax
(
x
,
mask
,
bias
)
class
Linear
(
TransformerEngineBaseLayer
):
"""Linear"""
out_features
:
int
=
512
kernel_axes
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
True
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
dense_general_cls
=
partial
(
DenseGeneral
,
features
=
self
.
out_features
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes
=
self
.
kernel_axes
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"linear"
,
dense_general_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
linear
(
x
)
class
LayerNormLinear
(
TransformerEngineBaseLayer
):
"""LayerNormLinear"""
out_features
:
int
=
512
enable_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
ln_bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
1.0
)
)
ln_bias_axes
:
Tuple
[
str
,
...]
=
()
kernel_axes
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
return_layernorm_output
:
bool
=
True
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
depth_scaling
:
float
=
None
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_dense_general_cls
=
partial
(
LayerNormDenseGeneral
,
features
=
self
.
out_features
,
enable_layernorm
=
self
.
enable_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
epsilon
=
self
.
epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
ln_bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
ln_bias_init
),
ln_bias_axes
=
self
.
ln_bias_axes
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes
=
self
.
kernel_axes
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
return_layernorm_output
=
self
.
return_layernorm_output
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
depth_scaling
=
self
.
depth_scaling
,
)
self
.
create_layer
(
"ln_linear"
,
ln_dense_general_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
ln_linear
(
x
)
class
LayerNormMLP
(
TransformerEngineBaseLayer
):
"""LayerNormMLP"""
intermediate_dim
:
int
=
2048
enable_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
ln_bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
1.0
)
)
ln_bias_axes
:
Tuple
[
str
,
...]
=
()
kernel_axes_1
:
Tuple
[
str
,
...]
=
()
kernel_axes_2
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes_1
:
Tuple
[
str
,
...]
=
()
bias_axes_2
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
return_layernorm_output
:
bool
=
True
activations
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"relu"
,)
intermediate_dropout_rate
:
float
=
0.1
intermediate_hidden_dropout_dims
:
Sequence
[
int
]
=
()
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_mlp_cls
=
partial
(
flax_LayerNormMLP
,
intermediate_dim
=
self
.
intermediate_dim
,
enable_layernorm
=
self
.
enable_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
epsilon
=
self
.
epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
ln_bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
ln_bias_init
),
ln_bias_axes
=
self
.
ln_bias_axes
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes_1
=
self
.
kernel_axes_1
,
kernel_axes_2
=
self
.
kernel_axes_2
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes_1
=
self
.
bias_axes_1
,
bias_axes_2
=
self
.
bias_axes_2
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
return_layernorm_output
=
self
.
return_layernorm_output
,
activations
=
self
.
activations
,
intermediate_dropout_rate
=
self
.
intermediate_dropout_rate
,
intermediate_hidden_dropout_dims
=
self
.
intermediate_hidden_dropout_dims
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"ln_mlp"
,
ln_mlp_cls
)
def
__call__
(
self
,
x
:
JTensor
,
deterministic
:
bool
=
False
)
->
JTensor
:
"""__call__"""
return
self
.
ln_mlp
(
x
,
deterministic
)
transformer_engine/jax/praxis/transformer.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related Transformer
"""
from
dataclasses
import
field
from
functools
import
partial
from
typing
import
Optional
,
Sequence
,
Tuple
import
warnings
from
praxis
import
pax_fiddle
from
praxis.base_layer
import
WeightInit
from
praxis.pytypes
import
JTensor
from
.module
import
TransformerEngineBaseLayer
from
..flax.transformer
import
TransformerLayerType
from
..flax.transformer
import
DotProductAttention
as
flax_DotProductAttention
from
..flax.transformer
import
MultiHeadAttention
as
flax_MultiHeadAttention
from
..flax.transformer
import
RelativePositionBiases
as
flax_RelativePositionBiases
from
..flax.transformer
import
TransformerLayer
as
flax_TransformerLayer
from
..attention
import
AttnBiasType
,
AttnMaskType
class
RelativePositionBiases
(
TransformerEngineBaseLayer
):
"""RelativePositionBiases"""
num_buckets
:
int
=
32
max_distance
:
int
=
128
num_attention_heads
:
int
=
64
embedding_init
:
WeightInit
=
None
embedding_axes
:
Tuple
[
str
,
...]
=
()
@
staticmethod
def
generate_embedding_init
(
init
,
num_attention_heads
,
num_buckets
):
"""generate_embedding_init"""
embedding_init
=
init
if
embedding_init
is
None
:
rb_stddev
=
(
num_attention_heads
*
num_buckets
)
**
-
0.5
embedding_init
=
WeightInit
.
Gaussian
(
rb_stddev
)
return
embedding_init
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
embedding_init
=
RelativePositionBiases
.
generate_embedding_init
(
self
.
embedding_init
,
self
.
num_attention_heads
,
self
.
num_buckets
)
rpb_cls
=
partial
(
flax_RelativePositionBiases
,
num_buckets
=
self
.
num_buckets
,
max_distance
=
self
.
max_distance
,
num_attention_heads
=
self
.
num_attention_heads
,
embedding_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"rel_embedding"
,
embedding_init
),
embedding_axes
=
self
.
embedding_axes
,
dtype
=
self
.
dtype
,
)
self
.
create_layer
(
"relative_position_bias"
,
rpb_cls
)
def
__call__
(
self
,
q_seqlen
:
JTensor
,
k_seqlen
:
JTensor
,
bidirectional
:
bool
=
True
)
->
JTensor
:
"""__call__"""
return
self
.
relative_position_bias
(
q_seqlen
,
k_seqlen
,
bidirectional
)
class
DotProductAttention
(
TransformerEngineBaseLayer
):
"""DotProductAttention"""
head_dim
:
int
=
0
num_attention_heads
:
int
=
0
num_gqa_groups
:
Optional
[
int
]
=
None
attention_dropout
:
float
=
0.0
attn_mask_type
:
AttnMaskType
=
"causal"
attn_bias_type
:
AttnBiasType
=
None
dropout_rng_name
:
str
=
"dropout"
float32_logits
:
bool
=
False
qkv_layout
:
str
=
"bshd_bshd_bshd"
scale_factor
:
Optional
[
float
]
=
None
transpose_batch_sequence
:
bool
=
True
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
assert
self
.
head_dim
>
0
,
f
"
{
self
.
head_dim
=
}
"
assert
self
.
num_attention_heads
>
0
,
f
"
{
self
.
num_attention_heads
=
}
"
dpa_cls
=
partial
(
flax_DotProductAttention
,
head_dim
=
self
.
head_dim
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
attn_mask_type
=
self
.
attn_mask_type
,
attn_bias_type
=
self
.
attn_bias_type
,
attention_dropout
=
self
.
attention_dropout
,
dtype
=
self
.
dtype
,
dropout_rng_name
=
self
.
dropout_rng_name
,
float32_logits
=
self
.
float32_logits
,
qkv_layout
=
self
.
qkv_layout
,
scale_factor
=
self
.
scale_factor
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"dot_product_attention"
,
dpa_cls
)
def
__call__
(
self
,
query
:
JTensor
,
key
:
JTensor
,
value
:
JTensor
,
mask
:
Optional
[
JTensor
]
=
None
,
bias
:
Optional
[
JTensor
]
=
None
,
*
,
deterministic
:
bool
=
False
,
)
->
JTensor
:
"""__call__"""
return
self
.
dot_product_attention
(
query
,
key
,
value
,
mask
,
bias
,
deterministic
=
deterministic
)
class
MultiHeadAttention
(
TransformerEngineBaseLayer
):
"""MultiHeadAttention"""
head_dim
:
int
=
0
num_attention_heads
:
int
=
0
num_gqa_groups
:
Optional
[
int
]
=
None
attention_dropout
:
float
=
0.0
dropout_rng_name
:
str
=
"dropout"
input_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
layernorm_epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
return_layernorm_output
:
bool
=
False
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
attn_mask_type
:
str
=
"causal"
attn_bias_type
:
Optional
[
str
]
=
None
enable_rotary_pos_emb
:
bool
=
False
rotary_pos_emb_windows
:
Tuple
[
int
,
int
]
=
(
1
,
10000
)
rotary_pos_emb_group_method
:
str
=
"consecutive"
low_rank_adaptation_scope
:
str
=
"none"
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
fuse_qkv_params
:
bool
=
True
transpose_batch_sequence
:
bool
=
True
enable_sequence_parallel
:
bool
=
False
scale_attn_logits
:
bool
=
False
scaled_query_init
:
bool
=
True
float32_logits
:
bool
=
False
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Deprecated parameters
num_heads
:
Optional
[
int
]
=
None
dropout_rate
:
Optional
[
float
]
=
None
output_layernorm
:
Optional
[
bool
]
=
None
apply_residual_connection_post_layernorm
:
Optional
[
bool
]
=
None
fuse_qkv
:
Optional
[
bool
]
=
None
def
__post_init__
(
self
):
# Deal with the deprecated parameters
if
self
.
num_heads
is
not
None
:
self
.
num_attention_heads
=
self
.
num_heads
warnings
.
warn
(
f
"
{
__class__
}
.num_heads is deprecated. It will be removed recently. "
f
"Please uses
{
__class__
}
.num_attention_heads as the new API."
,
DeprecationWarning
,
)
if
self
.
dropout_rate
is
not
None
:
self
.
attention_dropout
=
self
.
dropout_rate
warnings
.
warn
(
f
"
{
__class__
}
.dropout_rate is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.attention_dropout as the new API."
,
DeprecationWarning
,
)
if
self
.
apply_residual_connection_post_layernorm
is
not
None
:
warnings
.
warn
(
f
"
{
__class__
}
.apply_residual_connection_post_layernorm is deprecated. "
f
"It will be removed recently, please use
{
__class__
}
.return_layernorm_output."
,
DeprecationWarning
,
)
if
self
.
fuse_qkv
is
not
None
:
warnings
.
warn
(
f
"
{
__class__
}
.fuse_qkv is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.fuse_qkv_params as the new API."
,
DeprecationWarning
,
)
assert
self
.
output_layernorm
is
None
,
(
f
"
{
__class__
}
.output_layernorm is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.input_layernorm for controlling whether to apply layernorm."
)
if
self
.
num_gqa_groups
is
None
:
self
.
num_gqa_groups
=
self
.
num_heads
super
().
__post_init__
()
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
assert
self
.
head_dim
>
0
,
f
"
{
self
.
head_dim
=
}
"
assert
self
.
num_attention_heads
>
0
,
f
"
{
self
.
num_attention_heads
=
}
"
mha_cls
=
partial
(
flax_MultiHeadAttention
,
dtype
=
self
.
dtype
,
head_dim
=
self
.
head_dim
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
attention_dropout
=
self
.
attention_dropout
,
dropout_rng_name
=
self
.
dropout_rng_name
,
input_layernorm
=
self
.
input_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
layernorm_epsilon
=
self
.
layernorm_epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
return_layernorm_output
=
self
.
return_layernorm_output
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
attn_mask_type
=
self
.
attn_mask_type
,
attn_bias_type
=
self
.
attn_bias_type
,
enable_rotary_pos_emb
=
self
.
enable_rotary_pos_emb
,
rotary_pos_emb_windows
=
self
.
rotary_pos_emb_windows
,
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
low_rank_adaptation_scope
=
self
.
low_rank_adaptation_scope
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
fuse_qkv_params
=
self
.
fuse_qkv_params
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
enable_sequence_parallel
=
self
.
enable_sequence_parallel
,
scale_attn_logits
=
self
.
scale_attn_logits
,
scaled_query_init
=
self
.
scaled_query_init
,
float32_logits
=
self
.
float32_logits
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"multi_head_attn"
,
mha_cls
)
def
__call__
(
self
,
inputs_q
:
JTensor
,
inputs_kv
:
JTensor
,
mask
:
Optional
[
JTensor
]
=
None
,
bias
:
Optional
[
JTensor
]
=
None
,
*
,
decode
:
bool
=
False
,
deterministic
:
bool
=
False
,
)
->
JTensor
:
"""__call__"""
return
self
.
multi_head_attn
(
inputs_q
,
inputs_kv
,
mask
,
bias
,
decode
=
decode
,
deterministic
=
deterministic
)
class
TransformerLayer
(
TransformerEngineBaseLayer
):
"""TransformerLayer"""
hidden_size
:
int
=
512
mlp_hidden_size
:
int
=
2048
num_attention_heads
:
int
=
8
num_gqa_groups
:
Optional
[
int
]
=
None
layernorm_type
:
str
=
"layernorm"
layernorm_epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
hidden_dropout
:
float
=
0.1
hidden_dropout_dims
:
Sequence
[
int
]
=
()
attention_dropout
:
float
=
0.1
intermediate_dropout
:
float
=
0.1
intermediate_dropout_dims
:
Sequence
[
int
]
=
()
dropout_rng_name
:
str
=
"dropout"
mlp_activations
:
Sequence
[
str
]
=
(
"relu"
,)
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
apply_residual_connection_post_layernorm
:
bool
=
False
output_layernorm
:
bool
=
False
float32_attention_logits
:
bool
=
False
layer_type
:
TransformerLayerType
=
TransformerLayerType
.
ENCODER
self_attn_mask_type
:
str
=
"causal"
self_attn_bias_type
:
Optional
[
str
]
=
None
enable_rotary_pos_emb
:
bool
=
False
rotary_pos_emb_windows
:
Tuple
[
int
,
int
]
=
(
1
,
10000
)
rotary_pos_emb_group_method
:
str
=
"consecutive"
low_rank_adaptation_scope
:
str
=
"none"
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
enable_relative_embedding
:
bool
=
True
relative_embedding
:
pax_fiddle
.
Config
[
RelativePositionBiases
]
=
pax_fiddle
.
template_field
(
None
)
drop_path
:
float
=
0.0
fuse_qkv_params
:
bool
=
True
transpose_batch_sequence
:
bool
=
False
enable_sequence_parallel
:
bool
=
False
scale_attn_logits
:
bool
=
False
scaled_query_init
:
bool
=
True
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
def
__post_init__
(
self
):
if
self
.
num_gqa_groups
is
None
:
self
.
num_gqa_groups
=
self
.
num_attention_heads
super
().
__post_init__
()
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
relative_embedding_flax_module
=
None
if
self
.
enable_relative_embedding
and
self
.
relative_embedding
is
not
None
:
assert
self
.
relative_embedding
.
num_attention_heads
==
self
.
num_attention_heads
,
(
"TransformerLayer.relative_embedding.num_attention_heads shoule be"
"the same as TransformerLayer.num_attention_heads."
)
embedding_init
=
RelativePositionBiases
.
generate_embedding_init
(
self
.
relative_embedding
.
embedding_init
,
self
.
relative_embedding
.
num_attention_heads
,
self
.
relative_embedding
.
num_buckets
,
)
relative_embedding_flax_module
=
flax_RelativePositionBiases
(
num_buckets
=
self
.
relative_embedding
.
num_buckets
,
max_distance
=
self
.
relative_embedding
.
max_distance
,
num_attention_heads
=
self
.
relative_embedding
.
num_attention_heads
,
embedding_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"rel_embedding"
,
embedding_init
),
embedding_axes
=
self
.
relative_embedding
.
embedding_axes
,
dtype
=
self
.
relative_embedding
.
dtype
,
)
transformerlayer_cls
=
partial
(
flax_TransformerLayer
,
dtype
=
self
.
dtype
,
hidden_size
=
self
.
hidden_size
,
mlp_hidden_size
=
self
.
mlp_hidden_size
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
layernorm_type
=
self
.
layernorm_type
,
layernorm_epsilon
=
self
.
layernorm_epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
hidden_dropout
=
self
.
hidden_dropout
,
hidden_dropout_dims
=
self
.
hidden_dropout_dims
,
attention_dropout
=
self
.
attention_dropout
,
intermediate_dropout
=
self
.
intermediate_dropout
,
intermediate_dropout_dims
=
self
.
intermediate_dropout_dims
,
dropout_rng_name
=
self
.
dropout_rng_name
,
mha_kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"mha_kernel"
,
self
.
params_init
),
mlp_kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"mlp_kernel"
,
self
.
params_init
),
mlp_activations
=
self
.
mlp_activations
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
apply_residual_connection_post_layernorm
=
self
.
apply_residual_connection_post_layernorm
,
output_layernorm
=
self
.
output_layernorm
,
float32_attention_logits
=
self
.
float32_attention_logits
,
layer_type
=
self
.
layer_type
,
self_attn_mask_type
=
self
.
self_attn_mask_type
,
self_attn_bias_type
=
self
.
self_attn_bias_type
,
enable_rotary_pos_emb
=
self
.
enable_rotary_pos_emb
,
rotary_pos_emb_windows
=
self
.
rotary_pos_emb_windows
,
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
low_rank_adaptation_scope
=
self
.
low_rank_adaptation_scope
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
enable_relative_embedding
=
self
.
enable_relative_embedding
,
relative_embedding
=
relative_embedding_flax_module
,
drop_path
=
self
.
drop_path
,
fuse_qkv_params
=
self
.
fuse_qkv_params
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
enable_sequence_parallel
=
self
.
enable_sequence_parallel
,
scale_attn_logits
=
self
.
scale_attn_logits
,
scaled_query_init
=
self
.
scaled_query_init
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"transformerlayer"
,
transformerlayer_cls
)
def
__call__
(
self
,
inputs
:
JTensor
,
encoded
:
JTensor
=
None
,
attention_mask
:
JTensor
=
None
,
encoder_decoder_mask
:
JTensor
=
None
,
deterministic
:
bool
=
False
,
decode
:
bool
=
False
,
max_decode_length
:
bool
=
None
,
)
->
JTensor
:
"""__call__"""
return
self
.
transformerlayer
(
inputs
,
encoded
,
attention_mask
,
encoder_decoder_mask
,
deterministic
,
decode
,
max_decode_length
,
)
transformer_engine/jax/quantize/dequantizer.py
View file @
ab3e5a92
...
@@ -57,26 +57,35 @@ class Dequantizer:
...
@@ -57,26 +57,35 @@ class Dequantizer:
data
=
scaled_tensor
.
data
.
astype
(
jnp
.
float32
)
data
=
scaled_tensor
.
data
.
astype
(
jnp
.
float32
)
data_shape
=
data
.
shape
data_shape
=
data
.
shape
scale
=
scaled_tensor
.
scale_inv
.
view
(
jnp
.
uint8
).
astype
(
jnp
.
float32
)
scale
=
scaled_tensor
.
scale_inv
.
view
(
jnp
.
uint8
).
astype
(
jnp
.
float32
)
flatten_axis
=
scaled_tensor
.
flatten_axis
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
assert
(
0
<
flatten_axis
<
len
(
data_shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
scale_shape
=
scaled_tensor
.
scaling_mode
.
get_scale_shape
(
scale_shape
=
scaled_tensor
.
scaling_mode
.
get_scale_shape
(
scaled_tensor
.
data
.
shape
,
scaled_tensor
.
is_colwise
,
is_padded
=
False
data
_
shape
,
scaled_tensor
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
)
scale
=
jax
.
lax
.
slice
(
scale
,
[
0
]
*
len
(
scale_shape
),
scale_shape
)
# slice out the padding
scale
=
jax
.
lax
.
slice
(
scale
,
[
0
]
*
len
(
scale_shape
),
scale_shape
)
# slice out the padding
data
=
data
.
reshape
(
data
=
data
.
reshape
(
*
data_shape
[:
-
2
],
*
data_shape
[:
flatten_axis
-
1
],
scale_shape
[
-
2
],
scale_shape
[
flatten_axis
-
1
],
int
(
data_shape
[
-
2
]
/
scale_shape
[
-
2
]),
int
(
data_shape
[
flatten_axis
-
1
]
/
scale_shape
[
flatten_axis
-
1
]),
*
data_shape
[
flatten_axis
:
-
1
],
scale_shape
[
-
1
],
scale_shape
[
-
1
],
int
(
data_shape
[
-
1
]
/
scale_shape
[
-
1
]),
int
(
data_shape
[
-
1
]
/
scale_shape
[
-
1
]),
)
)
scale
=
jnp
.
expand_dims
(
scale
,
axis
=
(
-
1
,
-
3
))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale
=
jnp
.
expand_dims
(
scale
,
axis
=
(
flatten_axis
+
2
-
2
,
-
1
))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return
jnp
.
asarray
(
data
*
jnp
.
power
(
2
,
scale
-
127
),
scaled_tensor
.
dq_dtype
).
reshape
(
return
jnp
.
asarray
(
data
*
jnp
.
power
(
2
,
scale
-
127
),
scaled_tensor
.
dq_dtype
).
reshape
(
data_shape
data_shape
)
)
funcs
=
{
funcs
=
{
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
ScalingMode
.
DELAYED_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
_dq_func_block_scaling
,
ScalingMode
.
MXFP8_1D_SCALING
:
_dq_func_block_scaling
,
}
}
@
staticmethod
@
staticmethod
...
...
transformer_engine/jax/quantize/helper.py
View file @
ab3e5a92
...
@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
...
@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
from
..
import
cpp_extensions
as
tex
from
..
import
cpp_extensions
as
tex
__all__
=
[
"QuantizeConfig"
,
"fp8_autocast"
,
"is_fp8_available"
,
"update_collections"
]
__all__
=
[
"QuantizeConfig"
,
"fp8_autocast"
,
"is_fp8_available"
,
"update_collections"
,
"get_delayed_scaling"
,
"NVTE_FP8_COLLECTION_NAME"
,
]
_is_fp8_available
=
None
_is_fp8_available
=
None
_reason_for_no_fp8
=
""
_reason_for_no_fp8
=
""
...
@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
...
@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message
A tuple of (bool, str) indicating support and any error message
"""
"""
gpu_arch
=
get_device_compute_capability
(
gpu_id
)
gpu_arch
=
get_device_compute_capability
(
gpu_id
)
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
return
_check_delayed_scaling_fp8_support
(
gpu_arch
)
return
_check_delayed_scaling_fp8_support
(
gpu_arch
)
if
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
if
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
return
_check_block_scaling_fp8_support
(
gpu_arch
)
return
_check_block_scaling_fp8_support
(
gpu_arch
)
return
(
False
,
"Unsupported scaling_mode!"
)
return
(
False
,
"Unsupported scaling_mode!"
)
def
is_fp8_available
(
def
is_fp8_available
(
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
gpu_id
=
None
,
gpu_id
=
None
,
)
->
Tuple
[
bool
,
str
]:
)
->
Tuple
[
bool
,
str
]:
"""Check if FP8 is available for the given scaling mode and GPU.
"""Check if FP8 is available for the given scaling mode and GPU.
...
@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
...
@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
ValueError: If the recipe type is not supported
ValueError: If the recipe type is not supported
"""
"""
if
isinstance
(
fp8_recipe
,
recipe
.
DelayedScaling
):
if
isinstance
(
fp8_recipe
,
recipe
.
DelayedScaling
):
return
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
return
ScalingMode
.
DELAYED_TENSOR_SCALING
if
isinstance
(
fp8_recipe
,
recipe
.
MXFP8BlockScaling
):
if
isinstance
(
fp8_recipe
,
recipe
.
MXFP8BlockScaling
):
return
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
return
ScalingMode
.
MXFP8_1D_SCALING
raise
ValueError
(
"Invalid fp8_recipe!"
)
raise
ValueError
(
"Invalid fp8_recipe!"
)
def
update_collections
(
new
:
Collection
,
original
:
Collection
)
->
Collection
:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert
isinstance
(
original
,
(
dict
,
FrozenDict
))
assert
isinstance
(
new
,
(
dict
,
FrozenDict
))
frozen_original
=
FrozenDict
(
original
)
if
not
isinstance
(
original
,
FrozenDict
)
else
original
for
key
in
new
:
if
key
in
frozen_original
:
frozen_original
,
_
=
frozen_original
.
pop
(
key
)
new_coll
=
FrozenDict
({
**
new
,
**
frozen_original
})
if
not
isinstance
(
original
,
FrozenDict
):
new_coll
=
new_coll
.
unfreeze
()
return
new_coll
class
QuantizeConfig
:
class
QuantizeConfig
:
"""Configuration class for quantization settings.
"""Configuration class for quantization settings.
...
@@ -227,7 +209,7 @@ class QuantizeConfig:
...
@@ -227,7 +209,7 @@ class QuantizeConfig:
INITIALIZED
=
False
INITIALIZED
=
False
MARGIN
:
float
=
0.0
MARGIN
:
float
=
0.0
COLLECTION_NAME
:
str
=
"
quantize
_meta"
COLLECTION_NAME
:
str
=
"
fp8
_meta
s
"
FP8_FORMAT
:
recipe
.
Format
=
recipe
.
Format
.
HYBRID
FP8_FORMAT
:
recipe
.
Format
=
recipe
.
Format
.
HYBRID
FWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
0
]
FWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
0
]
BWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
1
]
BWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
1
]
...
@@ -235,7 +217,7 @@ class QuantizeConfig:
...
@@ -235,7 +217,7 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD
:
bool
=
False
FP8_2X_ACC_DGRAD
:
bool
=
False
FP8_2X_ACC_WGRAD
:
bool
=
False
FP8_2X_ACC_WGRAD
:
bool
=
False
IF_QUANTIZE_2X
:
bool
=
False
IF_QUANTIZE_2X
:
bool
=
False
SCALING_MODE
:
ScalingMode
=
ScalingMode
.
NVTE_
NO_SCALING
SCALING_MODE
:
ScalingMode
=
ScalingMode
.
NO_SCALING
# DelayedScaling
# DelayedScaling
AMAX_HISTORY_LEN
:
int
=
1024
AMAX_HISTORY_LEN
:
int
=
1024
...
@@ -271,11 +253,11 @@ class QuantizeConfig:
...
@@ -271,11 +253,11 @@ class QuantizeConfig:
cls
.
MARGIN
=
0.0
cls
.
MARGIN
=
0.0
cls
.
FP8_FORMAT
=
recipe
.
Format
.
HYBRID
cls
.
FP8_FORMAT
=
recipe
.
Format
.
HYBRID
cls
.
FWD_DTYPE
,
cls
.
BWD_DTYPE
=
_format2dtypes
(
cls
.
FP8_FORMAT
)
cls
.
FWD_DTYPE
,
cls
.
BWD_DTYPE
=
_format2dtypes
(
cls
.
FP8_FORMAT
)
cls
.
SCALING_MODE
=
ScalingMode
.
NVTE_
NO_SCALING
cls
.
SCALING_MODE
=
ScalingMode
.
NO_SCALING
cls
.
FP8_2X_ACC_FPROP
=
False
cls
.
FP8_2X_ACC_FPROP
=
False
cls
.
FP8_2X_ACC_DGRAD
=
False
cls
.
FP8_2X_ACC_DGRAD
=
False
cls
.
FP8_2X_ACC_WGRAD
=
False
cls
.
FP8_2X_ACC_WGRAD
=
False
cls
.
SCALING_MODE
=
ScalingMode
.
NVTE_
NO_SCALING
cls
.
SCALING_MODE
=
ScalingMode
.
NO_SCALING
cls
.
IF_QUANTIZE_2X
=
False
cls
.
IF_QUANTIZE_2X
=
False
# DelayedScaling
# DelayedScaling
cls
.
AMAX_HISTORY_LEN
=
1024
cls
.
AMAX_HISTORY_LEN
=
1024
...
@@ -414,3 +396,56 @@ def fp8_autocast(
...
@@ -414,3 +396,56 @@ def fp8_autocast(
yield
yield
finally
:
finally
:
Config
.
finalize
()
Config
.
finalize
()
def
get_delayed_scaling
():
r
"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo
=
(
"max"
if
QuantizeConfig
.
AMAX_COMPUTE_ALGO
is
AmaxComputeAlgo
.
MAX
else
"most_recent"
)
return
recipe
.
DelayedScaling
(
margin
=
int
(
QuantizeConfig
.
MARGIN
),
fp8_format
=
QuantizeConfig
.
FP8_FORMAT
,
amax_history_len
=
QuantizeConfig
.
AMAX_HISTORY_LEN
,
amax_compute_algo
=
amax_compute_algo
,
)
def
update_collections
(
new
:
Collection
,
original
:
Collection
)
->
Collection
:
r
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert
isinstance
(
original
,
(
dict
,
FrozenDict
))
assert
isinstance
(
new
,
(
dict
,
FrozenDict
))
frozen_original
=
FrozenDict
(
original
)
if
not
isinstance
(
original
,
FrozenDict
)
else
original
for
key
in
new
:
if
key
in
frozen_original
:
frozen_original
,
_
=
frozen_original
.
pop
(
key
)
new_coll
=
FrozenDict
({
**
new
,
**
frozen_original
})
if
not
isinstance
(
original
,
FrozenDict
):
new_coll
=
new_coll
.
unfreeze
()
return
new_coll
NVTE_FP8_COLLECTION_NAME
=
QuantizeConfig
.
COLLECTION_NAME
transformer_engine/jax/quantize/quantizer.py
View file @
ab3e5a92
...
@@ -14,7 +14,7 @@ from typing import Union, Optional
...
@@ -14,7 +14,7 @@ from typing import Union, Optional
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
Quantize
Axis
from
transformer_engine_jax
import
Quantize
Layout
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
from
.tensor
import
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
from
.tensor
import
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
...
@@ -24,7 +24,7 @@ from .helper import (
...
@@ -24,7 +24,7 @@ from .helper import (
)
)
__all__
=
[
__all__
=
[
"Quantize
Axis
"
,
"Quantize
Layout
"
,
"Quantizer"
,
"Quantizer"
,
"QuantizerSet"
,
"QuantizerSet"
,
"DelayedScaleQuantizer"
,
"DelayedScaleQuantizer"
,
...
@@ -45,12 +45,12 @@ class Quantizer(ABC):
...
@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes:
Attributes:
q_dtype: The data type for quantized values
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
scaling_mode: The scaling mode to use for quantization
q_
axis
: The quantization axis (row-wise, column-wise, or both)
q_
layout
: The quantization axis (row-wise, column-wise, or both)
"""
"""
q_dtype
:
jnp
.
dtype
q_dtype
:
jnp
.
dtype
scaling_mode
:
ScalingMode
scaling_mode
:
ScalingMode
q_
axis
:
Quantize
Axis
q_
layout
:
Quantize
Layout
def
tree_flatten
(
self
):
def
tree_flatten
(
self
):
"""Flatten the quantizer for JAX tree operations.
"""Flatten the quantizer for JAX tree operations.
...
@@ -59,7 +59,7 @@ class Quantizer(ABC):
...
@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
Tuple of (children, aux_data) for tree operations
"""
"""
children
=
()
children
=
()
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
axis
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
layout
)
return
(
children
,
aux_data
)
return
(
children
,
aux_data
)
@
classmethod
@
classmethod
...
@@ -85,30 +85,31 @@ class Quantizer(ABC):
...
@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns:
Returns:
True if using both row-wise and column-wise quantization
True if using both row-wise and column-wise quantization
"""
"""
return
self
.
q_
axis
==
Quantize
Axis
.
ROWWISE_COLWISE
return
self
.
q_
layout
==
Quantize
Layout
.
ROWWISE_COLWISE
@
abstractmethod
@
abstractmethod
def
get_layout
(
self
)
->
str
:
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data layout.
"""Get the data
data_
layout.
Returns:
Returns:
Data layout in string format
Data
data_
layout in string format
"""
"""
@
abstractmethod
@
abstractmethod
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Core quantization function to be implemented by subclasses.
"""Core quantization function to be implemented by subclasses.
Args:
Args:
x: Input tensor to quantize
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x containing the quantized data
A ScaledTensor1x containing the quantized data
"""
"""
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
):
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
):
"""Quantize a tensor using the internal _quantize_func().
"""Quantize a tensor using the internal _quantize_func().
Args:
Args:
...
@@ -116,21 +117,26 @@ class Quantizer(ABC):
...
@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
"""
if
(
is_rowwise
and
is_colwise
)
or
self
.
is_2x2x
():
if
(
is_rowwise
and
is_colwise
)
or
self
.
is_2x2x
():
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
colwise_tensor
=
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
)
colwise_tensor
=
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
if
is_colwise
:
if
is_colwise
:
return
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
)
return
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
):
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
,
flatten_axis
=-
1
):
"""Get shapes for scale tensors.
"""Get shapes for scale tensors.
Args:
Args:
...
@@ -140,7 +146,7 @@ class Quantizer(ABC):
...
@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns:
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
"""
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
)
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
,
flatten_axis
)
def
get_scale_dtype
(
self
):
def
get_scale_dtype
(
self
):
"""Get the data type for scale tensors.
"""Get the data type for scale tensors.
...
@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
...
@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes:
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_
axis
: Quantization axis (default: ROWWISE_COLWISE)
q_
layout
: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
scale: Current scaling factor
amax_history: History of maximum absolute values
amax_history: History of maximum absolute values
"""
"""
scaling_mode
:
ScalingMode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
scaling_mode
:
ScalingMode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
q_
axis
:
Quantize
Axis
=
Quantize
Axis
.
ROWWISE_COLWISE
q_
layout
:
Quantize
Layout
=
Quantize
Layout
.
ROWWISE_COLWISE
scale
:
jnp
.
ndarray
=
field
(
default_factory
=
lambda
:
jnp
.
ones
((
1
,),
jnp
.
float32
))
scale
:
jnp
.
ndarray
=
field
(
default_factory
=
lambda
:
jnp
.
ones
((
1
,),
jnp
.
float32
))
amax_history
:
jnp
.
ndarray
=
field
(
amax_history
:
jnp
.
ndarray
=
field
(
...
@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
...
@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
Tuple of (children, aux_data) for tree operations
"""
"""
children
=
(
self
.
scale
,
self
.
amax_history
)
children
=
(
self
.
scale
,
self
.
amax_history
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
axis
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
layout
)
return
(
children
,
aux_data
)
return
(
children
,
aux_data
)
def
get_layout
(
self
)
->
str
:
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data layout string.
"""Get the data
data_
layout string.
Returns:
Returns:
Data layout in string format
Data
data_
layout in string format
Raises:
Raises:
ValueError: If quantization axis is invalid
ValueError: If quantization axis is invalid
"""
"""
layout
=
"NT"
data_layout
=
"NT"
if
self
.
q_axis
==
QuantizeAxis
.
ROWWISE_COLWISE
:
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
:
return
layout
return
data_layout
if
self
.
q_axis
==
QuantizeAxis
.
ROWWISE
:
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE
:
return
layout
[
0
]
return
data_layout
[
0
]
if
self
.
q_axis
==
QuantizeAxis
.
COLWISE
:
if
self
.
q_layout
==
QuantizeLayout
.
COLWISE
:
return
layout
[
1
]
return
data_layout
[
1
]
raise
ValueError
(
f
"Invalid q_axis:
{
self
.
q_axis
}
"
)
raise
ValueError
(
f
"Invalid q_layout:
{
self
.
q_layout
}
"
)
def
_quantize_func
(
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Quantize function helper for delayed scaling FP8.
"""Quantize function helper for delayed scaling FP8.
Args:
Args:
x: Input tensor to quantize
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x containing the quantized data
A ScaledTensor1x containing the quantized data
"""
"""
...
@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
...
@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv
=
scale_inv
,
scale_inv
=
scale_inv
,
scaling_mode
=
self
.
scaling_mode
,
scaling_mode
=
self
.
scaling_mode
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
def
quantize
(
self
,
x
,
is_rowwise
:
bool
=
None
,
is_colwise
:
bool
=
None
,
dq_dtype
=
None
):
def
quantize
(
self
,
x
,
is_rowwise
:
bool
=
None
,
is_colwise
:
bool
=
None
,
dq_dtype
=
None
,
flatten_axis
=-
1
):
"""Quantize a tensor using the internal _quantize_func().
"""Quantize a tensor using the internal _quantize_func().
Args:
Args:
...
@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
...
@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
"""
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
if
flatten_axis
<
0
:
flatten_axis
+=
x
.
ndim
assert
0
<
flatten_axis
<
x
.
ndim
,
"flatten_axis is out of bounds!"
is_rowwise
=
(
is_rowwise
=
(
is_rowwise
is_rowwise
if
is_rowwise
is
not
None
if
is_rowwise
is
not
None
else
(
self
.
q_
axis
==
Quantize
Axis
.
ROWWISE
or
self
.
is_2x2x
())
else
(
self
.
q_
layout
==
Quantize
Layout
.
ROWWISE
or
self
.
is_2x2x
())
)
)
is_colwise
=
(
is_colwise
=
(
is_colwise
is_colwise
if
is_colwise
is
not
None
if
is_colwise
is
not
None
else
(
self
.
q_
axis
==
Quantize
Axis
.
COLWISE
or
self
.
is_2x2x
())
else
(
self
.
q_
layout
==
Quantize
Layout
.
COLWISE
or
self
.
is_2x2x
())
)
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
colwise_tensor
=
None
colwise_tensor
=
None
if
is_colwise
:
if
is_colwise
:
colwise_tensor
=
ScaledTensorFactory
.
create_1x
(
colwise_tensor
=
ScaledTensorFactory
.
create_1x
(
data
=
jnp
.
transpose
(
rowwise_tensor
.
data
,
(
-
1
,
*
range
(
rowwise_tensor
.
data
.
ndim
-
1
))),
data
=
jnp
.
transpose
(
rowwise_tensor
.
data
,
(
*
range
(
flatten_axis
,
x
.
ndim
),
*
range
(
flatten_axis
))
),
scale_inv
=
rowwise_tensor
.
scale_inv
,
scale_inv
=
rowwise_tensor
.
scale_inv
,
scaling_mode
=
self
.
scaling_mode
,
scaling_mode
=
self
.
scaling_mode
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
is_colwise
=
True
,
is_colwise
=
True
,
layout
=
"T"
,
data_layout
=
"T"
,
flatten_axis
=
flatten_axis
,
)
)
if
is_colwise
and
is_rowwise
:
if
is_colwise
and
is_rowwise
:
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
...
@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
...
@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes:
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_
axis
: Quantization axis (default: ROWWISE_COLWISE)
q_
layout
: Quantization axis (default: ROWWISE_COLWISE)
"""
"""
scaling_mode
:
ScalingMode
=
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
scaling_mode
:
ScalingMode
=
ScalingMode
.
MXFP8_1D_SCALING
q_
axis
:
Quantize
Axis
=
Quantize
Axis
.
ROWWISE_COLWISE
q_
layout
:
Quantize
Layout
=
Quantize
Layout
.
ROWWISE_COLWISE
def
get_layout
(
self
)
->
str
:
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data layout string.
"""Get the data
data_
layout string.
Returns:
Returns:
Data layout in string format
Data
data_
layout in string format
"""
"""
if
self
.
is_2x2x
():
if
self
.
is_2x2x
():
return
"NN"
return
"NN"
return
"N"
return
"N"
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Quantize function helper for block scaling FP8.
"""Quantize function helper for block scaling FP8.
Args:
Args:
x: Input tensor to quantize
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x containing the quantized data
A ScaledTensor1x containing the quantized data
"""
"""
# TODO(Phuong): use quantize_func from JAX
# TODO(Phuong): use quantize_func from JAX
if
flatten_axis
<
0
:
flatten_axis
=
x
.
ndim
+
flatten_axis
assert
(
0
<=
flatten_axis
<
x
.
ndim
),
f
"Invalid flatten_axis:
{
flatten_axis
}
for tensor of shape
{
x
.
shape
}
"
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
x_shape
=
x
.
shape
x_shape
=
x
.
shape
scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
x_shape
,
is_colwise
,
is_padded
=
False
)
scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
x_shape
,
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
scale_dtype
=
self
.
scaling_mode
.
get_scale_dtype
()
scale_dtype
=
self
.
scaling_mode
.
get_scale_dtype
()
x
=
x
.
reshape
(
x
=
x
.
reshape
(
*
x_shape
[:
-
2
],
*
x_shape
[:
flatten_axis
-
1
],
scale_shape
[
-
2
],
scale_shape
[
flatten_axis
-
1
],
int
(
x_shape
[
-
2
]
/
scale_shape
[
-
2
]),
int
(
x_shape
[
flatten_axis
-
1
]
/
scale_shape
[
flatten_axis
-
1
]),
*
x_shape
[
flatten_axis
:
-
1
],
scale_shape
[
-
1
],
scale_shape
[
-
1
],
int
(
x_shape
[
-
1
]
/
scale_shape
[
-
1
]),
int
(
x_shape
[
-
1
]
/
scale_shape
[
-
1
]),
)
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
),
axis
=
(
-
3
,
-
1
),
keepdims
=
True
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
),
axis
=
(
flatten_axis
+
2
-
2
,
-
1
),
keepdims
=
True
)
MAX
=
jnp
.
finfo
(
self
.
q_dtype
).
max
.
astype
(
jnp
.
float32
)
MAX
=
jnp
.
finfo
(
self
.
q_dtype
).
max
.
astype
(
jnp
.
float32
)
scales
=
amax
.
astype
(
jnp
.
float32
)
/
MAX
scales
=
amax
.
astype
(
jnp
.
float32
)
/
MAX
...
@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
...
@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self
.
scaling_mode
,
self
.
scaling_mode
,
is_colwise
=
is_colwise
,
is_colwise
=
is_colwise
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
def
_cast_to_e8m0_with_rounding_up
(
self
,
scales
):
def
_cast_to_e8m0_with_rounding_up
(
self
,
scales
):
...
@@ -500,8 +530,8 @@ class QuantizerFactory:
...
@@ -500,8 +530,8 @@ class QuantizerFactory:
"""
"""
quantizer_type_map
=
{
quantizer_type_map
=
{
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
DelayedScaleQuantizer
,
ScalingMode
.
DELAYED_TENSOR_SCALING
:
DelayedScaleQuantizer
,
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
BlockScaleQuantizer
,
ScalingMode
.
MXFP8_1D_SCALING
:
BlockScaleQuantizer
,
}
}
@
staticmethod
@
staticmethod
...
@@ -509,7 +539,7 @@ class QuantizerFactory:
...
@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers
:
int
=
1
,
n_quantizers
:
int
=
1
,
scaling_mode
:
ScalingMode
=
None
,
scaling_mode
:
ScalingMode
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_
axis
:
Quantize
Axis
=
None
,
q_
layout
:
Quantize
Layout
=
None
,
**
kwargs
,
**
kwargs
,
)
->
Quantizer
:
)
->
Quantizer
:
"""Create one or more quantizers with specified parameters.
"""Create one or more quantizers with specified parameters.
...
@@ -518,15 +548,17 @@ class QuantizerFactory:
...
@@ -518,15 +548,17 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_dtype: Quantization data type
q_axis: Quantization axis
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization
**kwargs: Additional arguments for quantizer initialization
Returns:
Returns:
A single quantizer or tuple of quantizers
A single quantizer or tuple of quantizers
"""
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
assert
isinstance
(
scaling_mode
,
ScalingMode
),
"Invalid scaling_mode type"
if
scaling_mode
in
(
ScalingMode
.
NVTE_NO_SCALING
,
ScalingMode
.
NVTE_INVALID_SCALING
):
# import pdb; pdb.set_trace()
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
quantizers
=
[
None
]
*
n_quantizers
quantizers
=
[
None
]
*
n_quantizers
else
:
else
:
quantizers
=
[]
quantizers
=
[]
...
@@ -534,7 +566,7 @@ class QuantizerFactory:
...
@@ -534,7 +566,7 @@ class QuantizerFactory:
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
quantizers
.
append
(
quantizers
.
append
(
quantizer_type
(
quantizer_type
(
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
axis
=
q_axis
,
**
kwargs
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
layout
=
q_layout
,
**
kwargs
)
)
)
)
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
...
@@ -554,11 +586,11 @@ class QuantizerFactory:
...
@@ -554,11 +586,11 @@ class QuantizerFactory:
A QuantizerSet instance
A QuantizerSet instance
"""
"""
if
is_2x2x
:
if
is_2x2x
:
q_
axis
_x
=
q_
axis
_kernel
=
q_
axis
_dgrad
=
Quantize
Axis
.
ROWWISE_COLWISE
q_
layout
_x
=
q_
layout
_kernel
=
q_
layout
_dgrad
=
Quantize
Layout
.
ROWWISE_COLWISE
else
:
else
:
q_
axis
_x
=
Quantize
Axis
.
ROWWISE
q_
layout
_x
=
Quantize
Layout
.
ROWWISE
q_
axis
_kernel
=
Quantize
Axis
.
COLWISE
q_
layout
_kernel
=
Quantize
Layout
.
COLWISE
q_
axis
_dgrad
=
None
q_
layout
_dgrad
=
None
if
"quantize_meta_set"
in
kwargs
:
if
"quantize_meta_set"
in
kwargs
:
quantize_meta_set
=
kwargs
.
get
(
"quantize_meta_set"
)
quantize_meta_set
=
kwargs
.
get
(
"quantize_meta_set"
)
...
@@ -577,9 +609,11 @@ class QuantizerFactory:
...
@@ -577,9 +609,11 @@ class QuantizerFactory:
else
:
else
:
args_x
=
args_kernel
=
args_grad
=
{}
args_x
=
args_kernel
=
args_grad
=
{}
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_x
,
**
args_x
)
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_x
,
**
args_x
)
q_kernel
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_kernel
,
**
args_kernel
)
q_kernel
=
QuantizerFactory
.
create
(
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_axis_dgrad
,
**
args_grad
)
1
,
scaling_mode
,
fwd_dtype
,
q_layout_kernel
,
**
args_kernel
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_layout_dgrad
,
**
args_grad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
@
staticmethod
@
staticmethod
...
@@ -618,4 +652,4 @@ class QuantizerFactory:
...
@@ -618,4 +652,4 @@ class QuantizerFactory:
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
noop_quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
ScalingMode
.
NVTE_
NO_SCALING
)
noop_quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
ScalingMode
.
NO_SCALING
)
transformer_engine/jax/quantize/scaling_modes.py
View file @
ab3e5a92
...
@@ -16,11 +16,33 @@ from typing import Tuple, Dict
...
@@ -16,11 +16,33 @@ from typing import Tuple, Dict
from
functools
import
reduce
from
functools
import
reduce
import
operator
import
operator
from
jax.experimental.custom_partitioning
import
CompoundFactor
from
jax.tree_util
import
register_pytree_node_class
from
jax.tree_util
import
register_pytree_node_class
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
transformer_engine_jax
import
JAXX_Scaling_Mode
__all__
=
[
"ScalingMode"
]
__all__
=
[
"QuantizeShardyRules"
,
"ScalingMode"
]
@
dataclass
class
QuantizeShardyRules
:
"""Information necessary to shard scale tensors with Shardy.
Attributes:
input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
the axes in `input_spec`
colwise_rule: Likewise for the column-wise scale tensor.
factor_sizes: For block scaling, contains the block size factor, which is
used in `input_spec`.
"""
input_spec
:
Tuple
[
str
]
rowwise_rule
:
Tuple
[
str
]
colwise_rule
:
Tuple
[
str
]
factor_sizes
:
Dict
[
str
,
int
]
class
ScalingModeMetadataImpl
(
ABC
):
class
ScalingModeMetadataImpl
(
ABC
):
...
@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC):
@
abstractmethod
@
abstractmethod
def
get_scale_shape
(
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors.
"""Get the shape for scale tensors.
...
@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC):
...
@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Returns:
The shape for scale tensors
The shape for scale tensors
"""
"""
@
abstractmethod
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
class
DelayedScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
class
DelayedScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
"""Implementation for delayed scaling mode.
"""Implementation for delayed scaling mode.
...
@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return
jnp
.
float32
return
jnp
.
float32
def
get_scale_shape
(
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors in delayed scaling.
"""Get the shape for scale tensors in delayed scaling.
...
@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being scaled
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Returns:
The shape for scale tensors - (1,)
The shape for scale tensors - (1,)
...
@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del
data_shape
,
is_colwise
del
data_shape
,
is_colwise
return
(
1
,)
return
(
1
,)
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del
flatten_axis
input_spec
=
tuple
(
f
"x
{
i
}
"
for
i
in
range
(
input_rank
))
return
QuantizeShardyRules
(
input_spec
,
(
unique_var
,),
(
unique_var
,),
{})
class
BlockScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
class
BlockScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
"""Implementation for block scaling mode.
"""Implementation for block scaling mode.
...
@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
"""
return
jnp
.
float8_e8m0fnu
return
jnp
.
float8_e8m0fnu
def
_apply_scale_shape_correction
(
self
,
data_shape
,
n_scale_blocks
,
scale_block_dim
):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
if
len
(
data_shape
)
>
1
:
# handle last dim
assert
data_shape
[
-
1
]
%
scale_block_dim
==
0
last
=
data_shape
[
-
1
]
//
scale_block_dim
scale_shape
=
(
last
,)
assert
n_scale_blocks
%
last
==
0
n_scale_blocks
//=
last
# handle middle dim, exclude first and last
for
mid
in
reversed
(
data_shape
[
1
:
-
1
]):
scale_shape
=
(
mid
,)
+
scale_shape
assert
n_scale_blocks
%
mid
==
0
n_scale_blocks
//=
mid
scale_shape
=
(
n_scale_blocks
,)
+
scale_shape
else
:
scale_shape
=
(
n_scale_blocks
,)
assert
len
(
scale_shape
)
==
len
(
data_shape
),
f
"scale_shape
{
scale_shape
}
, data_shape
{
data_shape
}
"
return
scale_shape
def
get_scale_shape
(
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors in block scaling.
"""Get the shape for scale tensors in block scaling.
...
@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being quantized
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Returns:
The shape for scale tensors
The shape for scale tensors
...
@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
...
@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x
,
block_y
=
self
.
_block_dims
block_x
,
block_y
=
self
.
_block_dims
alignment_x
,
alignment_y
=
block_alignment
alignment_x
,
alignment_y
=
block_alignment
seq_axis
=
len
(
data_shape
)
-
2
if
flatten_axis
<
0
:
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
assert
(
assert
(
data_shape
[
seq_axis
]
%
block_x
==
0
0
<
flatten_axis
<
len
(
data_shape
)
),
f
"Input data of shape
{
data_shape
}
should be padded by
{
block_x
}
in axes=
{
seq_axis
}
"
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
assert
data_shape
[
flatten_axis
-
1
]
%
block_x
==
0
,
(
f
"Data shape
{
data_shape
}
should be divisible by block_x
{
block_x
}
in axis"
f
"
{
flatten_axis
-
1
}
"
)
assert
(
assert
(
data_shape
[
-
1
]
%
block_y
==
0
data_shape
[
-
1
]
%
block_y
==
0
),
f
"
Input data of
shape
{
data_shape
}
should be
padded b
y
{
block_y
}
in axis -1"
),
f
"
Data
shape
{
data_shape
}
should be
divisible by block_
y
{
block_y
}
in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
flattened_first_dim
=
reduce
(
operator
.
mul
,
data_shape
[:
flatten_axis
],
1
)
n_block_seq
=
data_shape
[
seq_axis
]
//
block_x
flattened_last_dim
=
reduce
(
operator
.
mul
,
data_shape
[
flatten_axis
:],
1
)
n_block_y
=
data_shape
[
-
1
]
//
block_y
n_flat_first_dim
=
reduce
(
operator
.
mul
,
data_shape
[:
seq_axis
],
1
)
*
n_block_seq
assert
flattened_first_dim
%
block_x
==
0
,
(
f
"Flattened first dim - mutiplication of axes=
{
tuple
(
range
(
0
,
flatten_axis
))
}
of shape"
f
"
{
data_shape
}
- should be divisible by block_x
{
block_x
}
"
)
assert
flattened_last_dim
%
block_y
==
0
,
(
"Flattened last dim - mutiplication of"
f
" axes=
{
tuple
(
range
(
flatten_axis
,
len
(
data_shape
)))
}
of shape
{
data_shape
}
- should be"
f
" divisible by block_y
{
block_y
}
"
)
# Padding
n_block_x
=
int
(
flattened_first_dim
/
block_x
)
n_flat_first_dim
=
((
n_flat_first_dim
+
alignment_x
-
1
)
//
alignment_x
)
*
alignment_x
n_block_y
=
int
(
flattened_last_dim
/
block_y
)
n_block_y
=
((
n_block_y
+
alignment_y
-
1
)
//
alignment_y
)
*
alignment_y
out_shape
=
()
# padding
for
i
in
range
(
seq_axis
):
n_block_x
=
int
(((
n_block_x
+
alignment_x
-
1
)
//
alignment_x
)
*
alignment_x
)
d
=
data_shape
[
i
]
n_block_y
=
int
(((
n_block_y
+
alignment_y
-
1
)
//
alignment_y
)
*
alignment_y
)
out_shape
+=
(
d
,)
assert
n_flat_first_dim
%
d
==
0
n_flat_first_dim
//=
d
out_shape
+=
(
n_flat_first_dim
,
n_block_y
)
first_dim_scale_shape
=
self
.
_apply_scale_shape_correction
(
data_shape
[:
flatten_axis
],
n_block_x
,
block_x
)
last_dim_scale_shape
=
self
.
_apply_scale_shape_correction
(
data_shape
[
flatten_axis
:],
n_block_y
,
block_y
)
return
out_shape
return
(
*
first_dim_scale_shape
,
*
last_dim_scale_shape
)
def get_shardy_sharding_rules(self, input_rank, unique_var, flatten_axis) -> QuantizeShardyRules:
    """Sharding rules for the input and (row, col)wise scale tensors.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Axis along which data can be flattened to 2D for quantization

    Returns:
        The Shardy rules for the scaling mode
    """
    input_spec = [f"x{i}" for i in range(input_rank)]
    # We have to use two different factors in the two CompoundFactors because of Shardy
    # verifier requirements, even though they are the same.
    rowwise_var = unique_var
    colwise_var = f"{unique_var}_"
    # The axis just before flatten_axis carries the colwise block factor; the last axis
    # carries the rowwise block factor.
    input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise")
    input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise")
    # The rowwise and colwise scale tensors should be sharded the same way as the input.
    # However, we need to adjust the dimensions where the block scaling factor applies.
    rowwise = input_spec.copy()
    rowwise[-1] = rowwise_var
    colwise = input_spec.copy()
    colwise[flatten_axis - 1] = colwise_var
    # This implementation needs to be updated for different block dims.
    assert self._block_dims == (1, 32)
    return QuantizeShardyRules(
        tuple(input_spec),
        tuple(rowwise),
        tuple(colwise),
        {"block_size_rowwise": 32, "block_size_colwise": 32},
    )
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
...
@@ -175,16 +315,14 @@ class ScalingMode(Enum):
...
@@ -175,16 +315,14 @@ class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization:
This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode
- NO_SCALING: No scaling applied
- NVTE_NO_SCALING: No scaling applied
"""
"""
NVTE_DELAYED_TENSOR_SCALING
=
0
NO_SCALING
=
JAXX_Scaling_Mode
.
NO_SCALING
NVTE_MXFP8_1D_SCALING
=
1
DELAYED_TENSOR_SCALING
=
JAXX_Scaling_Mode
.
DELAYED_TENSOR_SCALING
NVTE_INVALID_SCALING
=
2
MXFP8_1D_SCALING
=
JAXX_Scaling_Mode
.
MXFP8_1D_SCALING
NVTE_NO_SCALING
=
3
def
_get_impl
(
self
)
->
ScalingModeMetadataImpl
:
def
_get_impl
(
self
)
->
ScalingModeMetadataImpl
:
"""Get the implementation for this scaling mode.
"""Get the implementation for this scaling mode.
...
@@ -208,34 +346,54 @@ class ScalingMode(Enum):
...
@@ -208,34 +346,54 @@ class ScalingMode(Enum):
"""
"""
return
self
.
_get_impl
().
get_scale_dtype
()
return
self
.
_get_impl
().
get_scale_dtype
()
def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]:
    """Get shapes for both row-wise and column-wise scaling.

    Args:
        data_shape: Shape of the data tensor
        is_padded: Whether to use padded shapes
        flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

    Returns:
        Tuple of (rowwise_scale_shape, colwise_scale_shape)
    """
    rowwise_scale_shape = self.get_scale_shape(
        data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis
    )
    colwise_scale_shape = self.get_scale_shape(
        data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis
    )
    return (rowwise_scale_shape, colwise_scale_shape)
def get_scale_shape(self, data_shape, is_colwise, is_padded=True, flatten_axis=-1) -> Tuple[int]:
    """Get the shape for scale tensors in this mode.

    Args:
        data_shape: Shape of the data tensor
        is_colwise: Whether to use column-wise scaling
        is_padded: Whether to use padded shapes
        flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.

    Returns:
        The shape for scale tensors
    """
    # Delegate to the per-mode metadata implementation.
    return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis)
def get_shardy_sharding_rules(self, input_rank, unique_var, flatten_axis=-1) -> Tuple[Tuple[str]]:
    """Sharding rules for the input and (row, col)wise scale tensors.

    Args:
        input_rank: The rank of the input tensor (for which we produce the scale tensor)
        unique_var: An otherwise unused Shardy variable name prefix
        flatten_axis: Axis along which data can be flattened to 2D for quantization

    Returns:
        The Shardy rules for the scaling mode
    """
    # Forward to the metadata implementation selected for this scaling mode.
    impl = self._get_impl()
    return impl.get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
"""Compare this scaling mode with another.
"""Compare this scaling mode with another.
...
@@ -273,8 +431,8 @@ class ScalingMode(Enum):
...
@@ -273,8 +431,8 @@ class ScalingMode(Enum):
# Registry mapping each ScalingMode to the metadata implementation that knows how
# to compute its scale shapes, dtypes and sharding rules.
SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
    ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
    # WAR: NO_SCALING reuses the delayed-scaling metadata implementation as a workaround.
    ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
}
transformer_engine/jax/quantize/tensor.py
View file @
ab3e5a92
...
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
...
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
Quantize
Axis
from
transformer_engine_jax
import
Quantize
Layout
from
.scaling_modes
import
ScalingMode
from
.scaling_modes
import
ScalingMode
from
.dequantizer
import
Dequantizer
from
.dequantizer
import
Dequantizer
...
@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
...
@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
ValueError: If called on a tensor that doesn't support column-wise access
ValueError: If called on a tensor that doesn't support column-wise access
"""
"""
@abstractmethod
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Subclasses must implement this; see ScaledTensor1x/ScaledTensor2x for concrete behavior.

    Args:
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        The tensor with applied sharding constraints
    """
@
register_pytree_node_class
@
register_pytree_node_class
@
dataclass
@
dataclass
...
@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: The data type for dequantized values
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
"""
"""
data
:
jnp
.
ndarray
data
:
jnp
.
ndarray
...
@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype
:
jnp
.
dtype
dq_dtype
:
jnp
.
dtype
_dq_func
:
Callable
_dq_func
:
Callable
is_colwise
:
bool
is_colwise
:
bool
layout
:
str
data_layout
:
str
flatten_axis
:
int
=
-
1
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Validates and adjusts the scale_inv shape after initialization.
"""Validates and adjusts the scale_inv shape after initialization.
...
@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
and quantization direction. Pads the scale_inv if necessary.
"""
"""
flatten_axis
=
(
len
(
self
.
data
.
shape
)
+
self
.
flatten_axis
if
self
.
flatten_axis
<
0
else
self
.
flatten_axis
)
assert
(
0
<
flatten_axis
<
len
(
self
.
data
.
shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
self
.
data
.
shape
}
"
if
self
.
data_layout
==
"T"
:
flatten_axis
=
self
.
data
.
ndim
-
flatten_axis
self
.
flatten_axis
=
flatten_axis
expected_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
expected_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
,
flatten_axis
=
flatten_axis
)
)
expected_unpadded_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
expected_unpadded_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
)
if
self
.
scale_inv
.
shape
!=
expected_scale_shape
:
if
self
.
scale_inv
.
shape
!=
expected_scale_shape
:
assert
self
.
scale_inv
.
shape
==
expected_unpadded_scale_shape
,
(
assert
self
.
scale_inv
.
shape
==
expected_unpadded_scale_shape
,
(
...
@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
A tuple containing (children, aux_data) for tree operations
A tuple containing (children, aux_data) for tree operations
"""
"""
children
=
(
self
.
data
,
self
.
scale_inv
)
children
=
(
self
.
data
,
self
.
scale_inv
)
aux_data
=
(
self
.
scaling_mode
,
self
.
dq_dtype
,
self
.
_dq_func
,
self
.
is_colwise
,
self
.
layout
)
aux_data
=
(
self
.
scaling_mode
,
self
.
dq_dtype
,
self
.
_dq_func
,
self
.
is_colwise
,
self
.
data_layout
,
self
.
flatten_axis
,
)
return
(
children
,
aux_data
)
return
(
children
,
aux_data
)
def
dequantize
(
self
):
def
dequantize
(
self
):
...
@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor):
...
@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor):
raise
ValueError
(
"Calling get_colwise_tensor() from a rowwise ScaledTensor1x!"
)
raise
ValueError
(
"Calling get_colwise_tensor() from a rowwise ScaledTensor1x!"
)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        A new ScaledTensor1x with sharding constraints applied to its data (and, for
        MXFP8 block scaling, to its scale_inv as well)
    """
    if not logical_axis_names:
        return self
    # axis_names were given for N layout, so needs to be transpose for T layout
    if self.data_layout == "T":
        assert self.flatten_axis > 0
        flatten_axis = -self.flatten_axis
        # Rotate the axis names so they line up with the transposed data layout.
        axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis])
    else:
        axis_names = logical_axis_names
    data = with_sharding_constraint_by_logical_axes(self.data, axis_names)
    if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
        # TODO(Phuong): Handle padding !?
        scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names)
    else:
        scale_inv = self.scale_inv
    return ScaledTensor1x(
        data=data,
        scale_inv=scale_inv,
        scaling_mode=self.scaling_mode,
        dq_dtype=self.dq_dtype,
        _dq_func=self._dq_func,
        is_colwise=self.is_colwise,
        data_layout=self.data_layout,
        flatten_axis=self.flatten_axis,
    )
@
register_pytree_node_class
@
register_pytree_node_class
@
dataclass
@
dataclass
...
@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor):
...
@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor):
"""
"""
return
self
.
colwise_tensor
return
self
.
colwise_tensor
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
    """Applies sharding constraints to a tensor based on logical axis names.

    Args:
        logical_axis_names: Tuple of logical axis names for sharding

    Returns:
        A new ScaledTensor2x whose row-wise and column-wise tensors each have the
        sharding constraints applied
    """
    if not logical_axis_names:
        return self
    # Apply the same constraint to both halves of the 2x tensor.
    constrained = tuple(
        t.apply_sharding_constraint_by_logical_axes(logical_axis_names)
        for t in (self.rowwise_tensor, self.colwise_tensor)
    )
    return ScaledTensor2x(*constrained)
@
dataclass
@
dataclass
class
ScaledTensorFactory
:
class
ScaledTensorFactory
:
...
@@ -244,7 +335,13 @@ class ScaledTensorFactory:
...
@@ -244,7 +335,13 @@ class ScaledTensorFactory:
@
staticmethod
@
staticmethod
def
create_1x
(
def
create_1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
is_colwise
=
False
,
layout
=
"N"
data
,
scale_inv
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
is_colwise
=
False
,
data_layout
=
"N"
,
flatten_axis
=-
1
,
):
):
"""Creates a single-scale quantized tensor.
"""Creates a single-scale quantized tensor.
...
@@ -254,13 +351,16 @@ class ScaledTensorFactory:
...
@@ -254,13 +351,16 @@ class ScaledTensorFactory:
scaling_mode: The scaling mode for quantization
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor1x instance
A ScaledTensor1x instance
"""
"""
dq_func
=
Dequantizer
.
funcs
.
get
(
scaling_mode
)
dq_func
=
Dequantizer
.
funcs
.
get
(
scaling_mode
)
return
ScaledTensor1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dq_func
,
is_colwise
,
layout
)
return
ScaledTensor1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dq_func
,
is_colwise
,
data_layout
,
flatten_axis
)
@
staticmethod
@
staticmethod
def
create_2x
(
def
create_2x
(
...
@@ -270,7 +370,8 @@ class ScaledTensorFactory:
...
@@ -270,7 +370,8 @@ class ScaledTensorFactory:
colwise_scale_inv
,
colwise_scale_inv
,
scaling_mode
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
dq_dtype
=
jnp
.
bfloat16
,
layout
=
"NN"
,
data_layout
=
"NN"
,
flatten_axis
=-
1
,
):
):
"""Creates a double-scale quantized tensor.
"""Creates a double-scale quantized tensor.
...
@@ -281,7 +382,8 @@ class ScaledTensorFactory:
...
@@ -281,7 +382,8 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
Returns:
Returns:
A ScaledTensor2x instance
A ScaledTensor2x instance
...
@@ -294,7 +396,8 @@ class ScaledTensorFactory:
...
@@ -294,7 +396,8 @@ class ScaledTensorFactory:
dq_dtype
,
dq_dtype
,
dq_func
,
dq_func
,
is_colwise
=
False
,
is_colwise
=
False
,
layout
=
layout
[
0
],
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
)
)
colwise_tensor
=
ScaledTensor1x
(
colwise_tensor
=
ScaledTensor1x
(
colwise_data
,
colwise_data
,
...
@@ -303,7 +406,8 @@ class ScaledTensorFactory:
...
@@ -303,7 +406,8 @@ class ScaledTensorFactory:
dq_dtype
,
dq_dtype
,
dq_func
,
dq_func
,
is_colwise
=
True
,
is_colwise
=
True
,
layout
=
layout
[
1
],
data_layout
=
data_layout
[
1
],
flatten_axis
=
flatten_axis
,
)
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
...
@@ -315,8 +419,9 @@ class ScaledTensorFactory:
...
@@ -315,8 +419,9 @@ class ScaledTensorFactory:
colwise_scale_inv
:
jnp
.
ndarray
,
colwise_scale_inv
:
jnp
.
ndarray
,
scaling_mode
:
ScalingMode
,
scaling_mode
:
ScalingMode
,
dq_dtype
:
jnp
.
dtype
=
jnp
.
bfloat16
,
dq_dtype
:
jnp
.
dtype
=
jnp
.
bfloat16
,
layout
:
str
=
"NN"
,
data_layout
:
str
=
"NN"
,
q_axis
:
QuantizeAxis
=
QuantizeAxis
.
ROWWISE
,
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE
,
flatten_axis
:
int
=
-
1
,
):
):
"""Creates a scaled tensor based on the quantization axis.
"""Creates a scaled tensor based on the quantization axis.
...
@@ -327,13 +432,13 @@ class ScaledTensorFactory:
...
@@ -327,13 +432,13 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
data_
layout: The
data_
layout specification (default: "NN")
q_
axis
: The quantization axis (default: ROWWISE)
q_
layout
: The quantization axis (default: ROWWISE)
Returns:
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_
axis
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_
layout
"""
"""
if
q_
axis
==
Quantize
Axis
.
ROWWISE_COLWISE
:
if
q_
layout
==
Quantize
Layout
.
ROWWISE_COLWISE
:
return
ScaledTensorFactory
.
create_2x
(
return
ScaledTensorFactory
.
create_2x
(
data
,
data
,
scale_inv
,
scale_inv
,
...
@@ -341,12 +446,19 @@ class ScaledTensorFactory:
...
@@ -341,12 +446,19 @@ class ScaledTensorFactory:
colwise_scale_inv
,
colwise_scale_inv
,
scaling_mode
,
scaling_mode
,
dq_dtype
,
dq_dtype
,
layout
=
layout
,
data_layout
=
data_layout
,
flatten_axis
=
flatten_axis
,
)
)
is_colwise
=
q_
axis
==
Quantize
Axis
.
COLWISE
is_colwise
=
q_
layout
==
Quantize
Layout
.
COLWISE
return
ScaledTensorFactory
.
create_1x
(
return
ScaledTensorFactory
.
create_1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
is_colwise
=
is_colwise
,
layout
=
layout
[
0
]
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
is_colwise
=
is_colwise
,
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
)
)
...
@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
...
@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
Returns:
The tensor with applied sharding constraints
The tensor with applied sharding constraints
"""
"""
if
isinstance
(
x
,
ScaledTensor1x
):
if
isinstance
(
x
,
ScaledTensor
):
return
ScaledTensor1x
(
return
x
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
data
=
with_sharding_constraint_by_logical_axes
(
x
.
data
,
logical_axis_names
),
scale_inv
=
x
.
scale_inv
,
scaling_mode
=
x
.
scaling_mode
,
dq_dtype
=
x
.
dq_dtype
,
_dq_func
=
x
.
_dq_func
,
is_colwise
=
x
.
is_colwise
,
layout
=
x
.
layout
,
)
if
isinstance
(
x
,
ScaledTensor2x
):
return
ScaledTensor2x
(
rowwise_tensor
=
with_sharding_constraint_by_logical_axes
(
x
.
rowwise_tensor
,
logical_axis_names
),
colwise_tensor
=
with_sharding_constraint_by_logical_axes
(
x
.
colwise_tensor
,
logical_axis_names
),
)
return
original_with_sharding_constraint_by_logical_axes
(
x
,
logical_axis_names
)
return
original_with_sharding_constraint_by_logical_axes
(
x
,
logical_axis_names
)
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment