Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ab3e5a92
Commit
ab3e5a92
authored
May 09, 2025
by
yuguo
Browse files
Merge commit '
04c730c0
' of...
Merge commit '
04c730c0
' of
https://github.com/NVIDIA/TransformerEngine
parents
a8d19fd9
04c730c0
Changes
174
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1093 additions
and
1346 deletions
+1093
-1346
transformer_engine/jax/csrc/extensions.h
transformer_engine/jax/csrc/extensions.h
+14
-8
transformer_engine/jax/csrc/extensions/activation.cpp
transformer_engine/jax/csrc/extensions/activation.cpp
+111
-87
transformer_engine/jax/csrc/extensions/gemm.cpp
transformer_engine/jax/csrc/extensions/gemm.cpp
+96
-110
transformer_engine/jax/csrc/extensions/misc.h
transformer_engine/jax/csrc/extensions/misc.h
+24
-1
transformer_engine/jax/csrc/extensions/normalization.cpp
transformer_engine/jax/csrc/extensions/normalization.cpp
+10
-11
transformer_engine/jax/csrc/extensions/pybind.cpp
transformer_engine/jax/csrc/extensions/pybind.cpp
+9
-9
transformer_engine/jax/csrc/extensions/quantization.cpp
transformer_engine/jax/csrc/extensions/quantization.cpp
+93
-41
transformer_engine/jax/dense.py
transformer_engine/jax/dense.py
+37
-22
transformer_engine/jax/flax/module.py
transformer_engine/jax/flax/module.py
+90
-102
transformer_engine/jax/flax/transformer.py
transformer_engine/jax/flax/transformer.py
+12
-4
transformer_engine/jax/layernorm_dense.py
transformer_engine/jax/layernorm_dense.py
+19
-9
transformer_engine/jax/layernorm_mlp.py
transformer_engine/jax/layernorm_mlp.py
+54
-21
transformer_engine/jax/praxis/__init__.py
transformer_engine/jax/praxis/__init__.py
+0
-9
transformer_engine/jax/praxis/module.py
transformer_engine/jax/praxis/module.py
+0
-311
transformer_engine/jax/praxis/transformer.py
transformer_engine/jax/praxis/transformer.py
+0
-408
transformer_engine/jax/quantize/dequantizer.py
transformer_engine/jax/quantize/dequantizer.py
+16
-7
transformer_engine/jax/quantize/helper.py
transformer_engine/jax/quantize/helper.py
+70
-35
transformer_engine/jax/quantize/quantizer.py
transformer_engine/jax/quantize/quantizer.py
+102
-68
transformer_engine/jax/quantize/scaling_modes.py
transformer_engine/jax/quantize/scaling_modes.py
+200
-42
transformer_engine/jax/quantize/tensor.py
transformer_engine/jax/quantize/tensor.py
+136
-41
No files found.
transformer_engine/jax/csrc/extensions.h
View file @
ab3e5a92
...
...
@@ -31,6 +31,9 @@
#include "transformer_engine/activation.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING
(
transformer_engine
::
jax
::
JAXX_Scaling_Mode
);
namespace
transformer_engine
{
namespace
jax
{
...
...
@@ -40,6 +43,12 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ActLuHandler
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DActLuDBiasQuantizeHandler
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
);
// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
NormForwardHandler
);
...
...
@@ -47,7 +56,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
w_dtype
,
DType
out_dtype
,
NVTE_Norm_Type
norm_type
,
int
scaling_mode
,
NVTE_Norm_Type
norm_type
,
JAXX_Scaling_Mode
scaling_mode
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
is_training
);
...
...
@@ -61,13 +71,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DequantizeHandler
);
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
);
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
DActLuDBiasQuantizeHandler
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
int
scaling_mode
,
bool
is_2x
);
JAXX_Scaling_Mode
scaling_mode
,
QuantizeLayout
q_layout
);
// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL
(
ScaledSoftmaxForwardHandler
);
...
...
transformer_engine/jax/csrc/extensions/activation.cpp
View file @
ab3e5a92
...
...
@@ -11,21 +11,13 @@
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"
namespace
{
bool
is_gated
(
NVTE_Activation_Type
act_type
)
{
return
act_type
==
NVTE_Activation_Type
::
GEGLU
||
act_type
==
NVTE_Activation_Type
::
SWIGLU
||
act_type
==
NVTE_Activation_Type
::
REGLU
||
act_type
==
NVTE_Activation_Type
::
QGEGLU
||
act_type
==
NVTE_Activation_Type
::
SREGLU
;
}
}
// namespace
namespace
transformer_engine
{
namespace
jax
{
Error_Type
ActLuFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
colwise_output_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
int64_t
act_enum
,
int64_t
scaling_mode
_enum
,
Result_Type
amax_buf
,
int64_t
act_enum
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x_int
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
...
...
@@ -42,40 +34,59 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto
n
=
input_dims
.
back
();
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
auto
act_len
=
input_dims
[
input_dims
.
size
()
-
2
];
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
is_2x
=
static_cast
<
bool
>
(
is_2x_int
);
auto
flatten_axis
=
output_buf
->
dimensions
().
size
()
-
1
;
// output does not have act axis
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
act_len
*
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
static_cast
<
DType
>
(
in_dtype
));
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_
scaling_mode
(
scaling_mode
)
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
colwise_output
,
static_cast
<
DType
>
(
out_dtype
),
output_shape
);
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
colwise_output
,
out_dtype
,
tmp_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
colwise_scale_inv_buf
;
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
colwise_scale_inv_buf
->
dimensions
().
back
()});
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
}
switch
(
act_type
)
{
...
...
@@ -128,13 +139,13 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.
Ret
<
Buffer_Type
>
()
// scale_inv colwise
.
Ret
<
Buffer_Type
>
()
// amax
.
Attr
<
int64_t
>
(
"act_enum"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
bool
>
(
"is_2x"
),
FFI_CudaGraph_Traits
);
pybind11
::
tuple
GetDActDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
,
int
scaling_mode
,
bool
is_2x
)
{
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
dact_input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
...
...
@@ -153,7 +164,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto
dact_input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dact_input_shape
,
in_dtype
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
static_cast
<
NVTES
caling
M
ode
>
(
scaling_mode
));
auto
output_tensor
=
TensorWrapper
(
get_nvte_s
caling
_m
ode
(
scaling_mode
));
output_tensor
.
set_rowwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
...
...
@@ -162,8 +173,9 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_trans_shape
);
auto
&
tmp_shape
=
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
tmp_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
...
...
@@ -172,7 +184,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
}
}
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
)
{
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_amax
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
...
...
@@ -190,22 +202,25 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
Error_Type
DActLuDBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
act_input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
output
_trans
_buf
,
Result_Type
scale_inv_buf
,
Result_Type
trans
_scale_inv_buf
,
Result_Type
amax_
out_
buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
int64_t
scaling_mode_enum
,
bool
is_2x
,
bool
is_dbias
,
int64_t
act_enum
)
{
Result_Type
output_buf
,
Result_Type
colwise_
output_buf
,
Result_Type
scale_inv_buf
,
Result_Type
colwise
_scale_inv_buf
,
Result_Type
amax_buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
JAXX_Scaling_Mode
scaling_mode
,
int64_t
act_enum
,
bool
is_2x
,
bool
is_dbias
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
act_input
=
act_input_buf
.
untyped_data
();
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
auto
flatten_axis
=
output_buf
->
dimensions
().
size
()
-
2
;
// output has act axis
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans
_buf
->
untyped_data
();
auto
*
colwise_output
=
colwise_output
_buf
->
untyped_data
();
auto
*
dbias
=
dbias_buf
->
untyped_data
();
void
*
workspace
=
workspace_buf
->
untyped_data
();
...
...
@@ -213,67 +228,76 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto
act_input_dims
=
act_input_buf
.
dimensions
();
auto
workspace_dims
=
workspace_buf
->
dimensions
();
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1], ir_dz_shape == input_dims
auto
input_ranks
=
input_dims
.
size
();
auto
act_input_ranks
=
act_input_dims
.
size
();
auto
m
=
product
(
act_input_dims
,
0
,
act_input_dims
.
size
()
-
1
);
// 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
auto
n
=
act_input_dims
.
back
();
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
input_dims
.
back
()};
auto
act_input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
};
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto
act_len
=
act_input_dims
[
act_input_dims
.
size
()
-
2
];
NVTE_CHECK
(
act_input_dims
.
back
()
==
input_dims
.
back
(),
"Shape mismatch between activation input and gradient input"
);
auto
m
=
product
(
act_input_dims
,
0
,
act_input_dims
.
size
()
-
2
);
auto
n
=
input_dims
.
back
();
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
act_input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
*
act_len
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
*
act_len
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
*
act_len
,
m
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
*
act_len
};
std
::
vector
<
size_t
>
workspace_shape
(
workspace_dims
.
begin
(),
workspace_dims
.
end
());
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
act_input_tensor
=
TensorWrapper
(
act_input
,
act_input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax_out
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
cudaMemsetAsync
(
amax_out
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_amax
(
amax_out
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
if
(
is_2x
)
{
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
colwise_output
,
out_dtype
,
tmp_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
colwise_scale_inv_buf
=
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
trans_scale_inv_buf
;
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
colwise_scale_inv_buf
;
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
colwise_scale_inv_buf
->
dimensions
().
back
()});
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
}
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
workspace_dtype
);
auto
act_type
=
static_cast
<
NVTE_Activation_Type
>
(
act_enum
);
// fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
NVTE_CHECK
(
!
(
is_gated
(
act_type
)
&&
is_dbias
),
"Unsupported DGatedActedDBias Fusion!"
);
NVTE_CHECK
(
!
(
scaling_mode
==
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
&&
is_2x
&&
is_gated
(
act_type
)),
NVTE_CHECK
(
!
(
act_len
==
2
&&
is_dbias
),
"Unsupported DGatedActedDBias Fusion!"
);
NVTE_CHECK
(
!
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
&&
is_2x
&&
act_len
==
2
),
"TE/common does not support delayed scaling for 2x with gated activations."
);
if
(
is_dbias
)
{
...
...
@@ -361,10 +385,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"act_enum"
)
.
Attr
<
bool
>
(
"is_2x"
)
.
Attr
<
bool
>
(
"is_dbias"
)
.
Attr
<
int64_t
>
(
"act_enum"
),
.
Attr
<
bool
>
(
"is_dbias"
),
FFI_CudaGraph_Traits
);
}
// namespace jax
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/gemm.cpp
View file @
ab3e5a92
...
...
@@ -15,47 +15,34 @@
namespace
transformer_engine
{
namespace
jax
{
constexpr
static
size_t
MXFP8_BLOCK_SIZE
=
32
;
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type
GroupedGemmImpl
(
uint8_t
*
lhs_ptr
,
const
DType
&
lhs_dtype
,
uint8_t
*
lhs_sinv_ptr
,
const
DType
&
lhs_sinv_dtype
,
uint8_t
*
rhs_ptr
,
const
DType
&
rhs_dtype
,
uint8_t
*
rhs_sinv_ptr
,
const
DType
&
rhs_sinv_dtype
,
uint8_t
*
bias_ptr
,
const
DType
&
bias_dtype
,
uint8_t
*
out_ptr
,
const
DType
&
out_dtype
,
uint8_t
*
workspace_ptr
,
const
size_t
workspace_size
,
size_t
num_gemms
,
int32_t
*
dim_list_ptr
,
const
int64_t
&
scaling_mode
,
cudaStream_t
stream
)
{
size_t
lhs_dtype_bytes
=
te_dtype_bytes
(
lhs_dtype
);
size_t
rhs_dtype_bytes
=
te_dtype_bytes
(
rhs_dtype
);
size_t
lhs_sinv_dtype_bytes
=
te_dtype_bytes
(
lhs_sinv_dtype
);
size_t
rhs_sinv_dtype_bytes
=
te_dtype_bytes
(
rhs_sinv_dtype
);
size_t
bias_dtype_bytes
=
te_dtype_bytes
(
bias_dtype
);
size_t
out_dtype_bytes
=
te_dtype_bytes
(
out_dtype
);
NVTE_CHECK
(
lhs_dtype_bytes
==
rhs_dtype_bytes
,
"sizeof(lhs_dtype) != sizeof(rhs_dtype)"
);
NVTE_CHECK
(
lhs_sinv_dtype_bytes
==
rhs_sinv_dtype_bytes
,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"
);
size_t
dim_list_bytes
=
sizeof
(
int32_t
)
*
3
*
num_gemms
;
std
::
unique_ptr
<
int32_t
[]
>
dim_list_host
=
std
::
make_unique
<
int32_t
[]
>
(
3
*
num_gemms
);
cudaMemcpyAsync
(
dim_list_host
.
get
(),
dim_list_ptr
,
dim_list_bytes
,
cudaMemcpyDeviceToHost
,
stream
);
// Note: This may break cudaGraph.
cudaStreamSynchronize
(
stream
);
Error_Type
GroupedGemmFFI
(
cudaStream_t
stream
,
Variadic_Buffer_Type
input_list
,
Variadic_Result_Type
output_list
,
int64_t
num_gemms
,
JAXX_Scaling_Mode
scaling_mode
,
int64_t
has_bias
)
{
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// Jax uses row-major
data_
layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// cuBLAS uses column-major
data_
layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
if
(
num_gemms
<=
0
)
{
return
ffi_with_cuda_error_check
();
}
size_t
expected_input_size
=
has_bias
?
5
*
num_gemms
:
4
*
num_gemms
;
size_t
expected_output_size
=
num_gemms
+
1
;
size_t
actual_input_size
=
input_list
.
size
();
size_t
actual_output_size
=
output_list
.
size
();
NVTE_CHECK
(
actual_input_size
==
expected_input_size
,
"Expected %zu input tensors, got %zu"
,
expected_input_size
,
actual_input_size
);
NVTE_CHECK
(
actual_output_size
==
expected_output_size
,
"Expected %zu output tensors, got %zu"
,
expected_output_size
,
actual_output_size
);
bool
trans_lhs
=
true
;
bool
trans_rhs
=
false
;
auto
num_math_sm
=
cuda
::
sm_count
()
-
getenv
<
int
>
(
"NVTE_EXT_MARGIN_SM"
,
0
);
...
...
@@ -79,10 +66,40 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
std
::
vector
<
NVTETensor
>
out_list
;
std
::
vector
<
NVTETensor
>
workspace_list
;
int
lhs_list_offset
=
0
;
int
rhs_list_offset
=
num_gemms
;
int
lhs_sinv_list_offset
=
2
*
num_gemms
;
int
rhs_sinv_list_offset
=
3
*
num_gemms
;
int
bias_list_offset
=
4
*
num_gemms
;
int
out_list_offset
=
0
;
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
size_t
m
=
dim_list_host
[
i
*
3
];
size_t
n
=
dim_list_host
[
i
*
3
+
1
];
size_t
k
=
dim_list_host
[
i
*
3
+
2
];
Buffer_Type
lhs_i
=
input_list
.
get
<
Buffer_Type
>
(
lhs_list_offset
+
i
).
value
();
Buffer_Type
rhs_i
=
input_list
.
get
<
Buffer_Type
>
(
rhs_list_offset
+
i
).
value
();
Buffer_Type
lhs_sinv_i
=
input_list
.
get
<
Buffer_Type
>
(
lhs_sinv_list_offset
+
i
).
value
();
Buffer_Type
rhs_sinv_i
=
input_list
.
get
<
Buffer_Type
>
(
rhs_sinv_list_offset
+
i
).
value
();
Result_Type
out_i
=
output_list
.
get
<
Buffer_Type
>
(
out_list_offset
+
i
).
value
();
DType
lhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_i
.
element_type
());
DType
rhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_i
.
element_type
());
DType
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
out_i
->
element_type
());
void
*
lhs_ptr
=
lhs_i
.
untyped_data
();
void
*
rhs_ptr
=
rhs_i
.
untyped_data
();
void
*
lhs_sinv_ptr
=
lhs_sinv_i
.
untyped_data
();
void
*
rhs_sinv_ptr
=
rhs_sinv_i
.
untyped_data
();
void
*
out_ptr
=
out_i
->
untyped_data
();
// Placeholder for bias since it can be empty
DType
bias_dtype
=
DType
::
kFloat32
;
void
*
bias_ptr
=
nullptr
;
auto
lhs_shape_
=
lhs_i
.
dimensions
();
auto
rhs_shape_
=
rhs_i
.
dimensions
();
// lhs and rhs has shape [1, m, k] and [1, n, k]
size_t
m
=
lhs_shape_
[
1
];
size_t
n
=
rhs_shape_
[
1
];
size_t
k
=
lhs_shape_
[
2
];
auto
lhs_shape
=
std
::
vector
<
size_t
>
{
m
,
k
};
auto
rhs_shape
=
std
::
vector
<
size_t
>
{
n
,
k
};
...
...
@@ -90,54 +107,54 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
auto
lhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
auto
rhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
if
(
scaling_mode
==
NVTE_NO_SCALING
||
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
auto
lhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_shape
,
lhs_dtype
,
nullptr
,
nullptr
,
reinterpret_cast
<
float
*>
(
lhs_sinv_ptr
));
auto
rhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_shape
,
rhs_dtype
,
nullptr
,
nullptr
,
reinterpret_cast
<
float
*>
(
rhs_sinv_ptr
));
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
}
else
if
(
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
NVTE_CHECK
(
k
%
MXFP8_BLOCK_SIZE
==
0
,
"MXFP8 K-dim being divisble by %d (got %d)"
,
MXFP8_BLOCK_SIZE
,
k
);
size_t
sinv_k
=
k
/
MXFP8_BLOCK_SIZE
;
lhs_sinv_shape
[
0
]
=
m
;
lhs_sinv_shape
[
1
]
=
sinv_k
;
rhs_sinv_shape
[
0
]
=
n
;
rhs_sinv_shape
[
1
]
=
sinv_k
;
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
NO_SCALING
||
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
float
*
amax_dptr
=
nullptr
;
float
*
scale_dptr
=
nullptr
;
auto
lhs_i_
=
TensorWrapper
(
lhs_ptr
,
lhs_shape
,
lhs_dtype
,
amax_dptr
,
scale_dptr
,
reinterpret_cast
<
float
*>
(
lhs_sinv_ptr
));
auto
rhs_i_
=
TensorWrapper
(
rhs_ptr
,
rhs_shape
,
rhs_dtype
,
amax_dptr
,
scale_dptr
,
reinterpret_cast
<
float
*>
(
rhs_sinv_ptr
));
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i_
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i_
));
}
else
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
)
{
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper
lhs_i
(
NVTE_MXFP8_1D_SCALING
);
TensorWrapper
rhs_i
(
NVTE_MXFP8_1D_SCALING
);
lhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_dtype
,
lhs_shape
);
rhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_dtype
,
rhs_shape
);
lhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
lhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
lhs_sinv_shape
);
rhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
rhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
rhs_sinv_shape
);
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
}
else
{
NVTE_ERROR
(
"Unsupported scaling mode: "
,
scaling_mode
);
auto
lhs_sinv_shape_
=
lhs_sinv_i
.
dimensions
();
auto
rhs_sinv_shape_
=
rhs_sinv_i
.
dimensions
();
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
lhs_sinv_shape
[
i
]
=
lhs_sinv_shape_
[
i
];
rhs_sinv_shape
[
i
]
=
rhs_sinv_shape_
[
i
];
}
auto
out_i
=
TensorWrapper
(
static_cast
<
void
*>
(
out_ptr
),
out_shape
,
out_dtype
);
lhs_ptr
+=
m
*
k
*
lhs_dtype_bytes
;
rhs_ptr
+=
n
*
k
*
rhs_dtype_bytes
;
out_ptr
+=
m
*
n
*
out_dtype_bytes
;
lhs_sinv_ptr
+=
lhs_sinv_shape
[
0
]
*
lhs_sinv_shape
[
1
]
*
lhs_sinv_dtype_bytes
;
rhs_sinv_ptr
+=
rhs_sinv_shape
[
0
]
*
rhs_sinv_shape
[
1
]
*
rhs_sinv_dtype_bytes
;
NVTEScalingMode
nvte_scaling_mode
=
get_nvte_scaling_mode
(
scaling_mode
);
TensorWrapper
lhs_i_
(
nvte_scaling_mode
);
TensorWrapper
rhs_i_
(
nvte_scaling_mode
);
lhs_i_
.
set_rowwise_data
(
lhs_ptr
,
lhs_dtype
,
lhs_shape
);
rhs_i_
.
set_rowwise_data
(
rhs_ptr
,
rhs_dtype
,
rhs_shape
);
lhs_i_
.
set_rowwise_scale_inv
(
lhs_sinv_ptr
,
DType
::
kFloat8E8M0
,
lhs_sinv_shape
);
rhs_i_
.
set_rowwise_scale_inv
(
rhs_sinv_ptr
,
DType
::
kFloat8E8M0
,
rhs_sinv_shape
);
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i_
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i_
));
}
else
{
NVTE_ERROR
(
"Unsupported scaling mode: "
,
static_cast
<
int
>
(
scaling_mode
));
}
auto
out_i_
=
TensorWrapper
(
out_ptr
,
out_shape
,
out_dtype
);
void
*
pre_gelu_ptr
=
nullptr
;
auto
bias_shape
=
std
::
vector
<
size_t
>
{
0
};
auto
pre_gelu_shape
=
std
::
vector
<
size_t
>
{
0
};
if
(
bias_ptr
!=
nullptr
)
bias_shape
[
0
]
=
n
;
if
(
has_bias
)
{
auto
bias_i_get
=
input_list
.
get
<
Buffer_Type
>
(
bias_list_offset
+
i
);
Buffer_Type
bias_i
=
bias_i_get
.
value
();
bias_ptr
=
bias_i
.
untyped_data
();
bias_dtype
=
convert_ffi_datatype_to_te_dtype
(
bias_i
.
element_type
());
bias_shape
[
0
]
=
n
;
}
auto
bias_i
=
TensorWrapper
(
bias_ptr
,
bias_shape
,
bias_dtype
);
if
(
bias_ptr
!=
nullptr
)
bias_ptr
+=
n
*
bias_dtype_bytes
;
auto
pre_gelu_i
=
TensorWrapper
(
pre_gelu_ptr
,
pre_gelu_shape
,
out_dtype
);
out_wrapper_list
.
push_back
(
std
::
move
(
out_i
));
out_wrapper_list
.
push_back
(
std
::
move
(
out_i
_
));
bias_wrapper_list
.
push_back
(
std
::
move
(
bias_i
));
pre_gelu_wrapper_list
.
push_back
(
std
::
move
(
pre_gelu_i
));
...
...
@@ -148,6 +165,10 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
out_list
.
push_back
(
out_wrapper_list
.
back
().
data
());
}
auto
workspace_get
=
output_list
.
get
<
Buffer_Type
>
(
num_gemms
);
Result_Type
workspace
=
workspace_get
.
value
();
uint8_t
*
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace
->
untyped_data
());
size_t
workspace_size
=
workspace
->
dimensions
()[
0
]
/
num_streams
;
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
auto
workspace_i
=
...
...
@@ -165,49 +186,14 @@ Error_Type GroupedGemmImpl(uint8_t *lhs_ptr, const DType &lhs_dtype, uint8_t *lh
return
ffi_with_cuda_error_check
();
}
Error_Type
GroupedGemmFFI
(
cudaStream_t
stream
,
Buffer_Type
lhs_flatten
,
Buffer_Type
lhs_sinv_flatten
,
Buffer_Type
rhs_flatten
,
Buffer_Type
rhs_sinv_flatten
,
Buffer_Type
bias_flatten
,
Buffer_Type
dim_list
,
Result_Type
out_flatten
,
Result_Type
workspace_flatten
,
int64_t
num_gemms
,
int64_t
scaling_mode
)
{
// Inputs
auto
lhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_flatten
.
untyped_data
());
auto
rhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_flatten
.
untyped_data
());
auto
lhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_sinv_flatten
.
untyped_data
());
auto
rhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_sinv_flatten
.
untyped_data
());
auto
bias_ptr
=
reinterpret_cast
<
uint8_t
*>
(
bias_flatten
.
untyped_data
());
auto
dim_list_ptr
=
reinterpret_cast
<
int32_t
*>
(
dim_list
.
untyped_data
());
auto
lhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_flatten
.
element_type
());
auto
rhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_flatten
.
element_type
());
auto
lhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_sinv_flatten
.
element_type
());
auto
rhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_sinv_flatten
.
element_type
());
auto
bias_dtype
=
convert_ffi_datatype_to_te_dtype
(
bias_flatten
.
element_type
());
// Outputs
auto
out_ptr
=
reinterpret_cast
<
uint8_t
*>
(
out_flatten
->
untyped_data
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
out_flatten
->
element_type
());
auto
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace_flatten
->
untyped_data
());
auto
workspace_size
=
workspace_flatten
->
dimensions
().
back
()
/
num_streams
;
return
GroupedGemmImpl
(
lhs_ptr
,
lhs_dtype
,
lhs_sinv_ptr
,
lhs_sinv_dtype
,
rhs_ptr
,
rhs_dtype
,
rhs_sinv_ptr
,
rhs_sinv_dtype
,
bias_ptr
,
bias_dtype
,
out_ptr
,
out_dtype
,
workspace_ptr
,
workspace_size
,
num_gemms
,
dim_list_ptr
,
scaling_mode
,
stream
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
GroupedGemmHandler
,
GroupedGemmFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// lhs_flatten
.
Arg
<
Buffer_Type
>
()
// lhs_sinv_flatten
.
Arg
<
Buffer_Type
>
()
// rhs_flatten
.
Arg
<
Buffer_Type
>
()
// rhs_sinv_flatten
.
Arg
<
Buffer_Type
>
()
// bias_flatten
.
Arg
<
Buffer_Type
>
()
// dim_list
.
Ret
<
Buffer_Type
>
()
// out_flatten
.
Ret
<
Buffer_Type
>
()
// workspace_flatten
.
RemainingArgs
()
// input list
.
RemainingRets
()
// output list
.
Attr
<
int64_t
>
(
"num_gemms"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
),
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"has_bias"
),
FFI_CudaGraph_Traits
);
}
// namespace jax
...
...
transformer_engine/jax/csrc/extensions/misc.h
View file @
ab3e5a92
...
...
@@ -34,11 +34,34 @@ inline size_t product(const std::vector<size_t> &shape) {
return
ret
;
}
enum
class
Quantize
Axis
{
enum
class
Quantize
Layout
{
ROWWISE
,
COLWISE
,
ROWWISE_COLWISE
,
};
enum
class
JAXX_Scaling_Mode
:
int64_t
{
NO_SCALING
=
0
,
DELAYED_TENSOR_SCALING
=
1
,
MXFP8_1D_SCALING
=
2
,
};
static
NVTEScalingMode
get_nvte_scaling_mode
(
const
JAXX_Scaling_Mode
&
mode
)
{
switch
(
mode
)
{
case
JAXX_Scaling_Mode
::
NO_SCALING
:
return
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
;
break
;
case
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
:
return
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
;
break
;
case
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
:
return
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
;
break
;
default:
NVTE_ERROR
(
"Invalid Scaling Mode "
,
static_cast
<
int
>
(
mode
));
break
;
}
}
}
// namespace jax
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/normalization.cpp
View file @
ab3e5a92
...
...
@@ -14,7 +14,8 @@ namespace jax {
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
w_dtype
,
DType
out_dtype
,
NVTE_Norm_Type
norm_type
,
int
scaling_mode
,
NVTE_Norm_Type
norm_type
,
JAXX_Scaling_Mode
scaling_mode
,
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
is_training
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
...
...
@@ -26,12 +27,11 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
auto
gamma_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
in_dtype
);
auto
rsigma_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
output_tensor
.
set_rowwise_data
(
nullptr
,
out_dtype
,
input_shape
);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if
(
is_training
&&
_
scaling_mode
==
NVTE_
MXFP8_1D_SCALING
)
{
if
(
is_training
&&
scaling_mode
==
JAXX_Scaling_Mode
::
MXFP8_1D_SCALING
)
{
int
temp
=
1
;
output_tensor
.
set_columnwise_data
(
static_cast
<
void
*>
(
&
temp
),
out_dtype
,
input_shape
);
}
...
...
@@ -47,7 +47,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
}
else
{
NVTE_CHECK
(
scaling_mode
!=
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
NVTE_CHECK
(
scaling_mode
!=
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
epsilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
...
...
@@ -64,7 +64,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
int
norm_type
,
bool
zero_centered_gamma
,
double
epsilon
,
int64_t
sm_margin
,
int
scaling_mode
,
bool
is_2x
)
{
int64_t
sm_margin
,
JAXX_Scaling_Mode
scaling_mode
,
bool
is_2x
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
.
element_type
());
...
...
@@ -80,7 +80,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
*
workspace
=
wkspace_buf
->
untyped_data
();
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
_norm_type
=
static_cast
<
NVTE_Norm_Type
>
(
norm_type
);
auto
_is_2x
=
static_cast
<
bool
>
(
is_2x
);
...
...
@@ -105,7 +104,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
_sm_margin
;
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte
_scaling_mode
(
scaling_mode
)
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
input_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
...
...
@@ -117,7 +116,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
_
scaling_mode
==
NVTE_
DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
...
...
@@ -142,7 +141,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
}
else
{
NVTE_CHECK
(
scaling_mode
!=
NVTE
ScalingMode
::
NVTE_
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
NVTE_CHECK
(
scaling_mode
!=
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
_epsilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
...
...
@@ -170,7 +169,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI,
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"epsilon"
)
.
Attr
<
int64_t
>
(
"sm_margin"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
bool
>
(
"is_2x"
),
FFI_CudaGraph_Traits
);
...
...
transformer_engine/jax/csrc/extensions/pybind.cpp
View file @
ab3e5a92
...
...
@@ -138,17 +138,17 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.
value
(
"RMSNorm"
,
NVTE_Norm_Type
::
RMSNorm
)
.
export_values
();
pybind11
::
enum_
<
NVTE
ScalingMode
>
(
m
,
"
NVTE
_Scaling_Mode"
,
pybind11
::
module_local
())
.
value
(
"N
VTE_DELAYED_TENSOR
_SCALING"
,
NVTE
ScalingMode
::
N
VTE_DELAYED_TENSOR
_SCALING
)
.
value
(
"
NVTE_MXFP8_1D
_SCALING"
,
NVTE
ScalingMode
::
NVTE_MXFP8_1D
_SCALING
)
.
value
(
"
NVTE_INVALI
D_SCALING"
,
NVTE
ScalingMode
::
NVTE_
MXFP8_1D_SCALING
)
pybind11
::
enum_
<
JAXX_
Scaling
_
Mode
>
(
m
,
"
JAXX
_Scaling_Mode"
,
pybind11
::
module_local
())
.
value
(
"N
O
_SCALING"
,
JAXX_
Scaling
_
Mode
::
N
O
_SCALING
)
.
value
(
"
DELAYED_TENSOR
_SCALING"
,
JAXX_
Scaling
_
Mode
::
DELAYED_TENSOR
_SCALING
)
.
value
(
"
MXFP8_1
D_SCALING"
,
JAXX_
Scaling
_
Mode
::
MXFP8_1D_SCALING
)
.
export_values
();
pybind11
::
enum_
<
transformer_engine
::
jax
::
Quantize
Axis
>
(
m
,
"Quantize
Axis
"
,
pybind11
::
enum_
<
transformer_engine
::
jax
::
Quantize
Layout
>
(
m
,
"Quantize
Layout
"
,
pybind11
::
module_local
())
.
value
(
"ROWWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
ROWWISE
)
.
value
(
"COLWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
COLWISE
)
.
value
(
"ROWWISE_COLWISE"
,
transformer_engine
::
jax
::
Quantize
Axis
::
ROWWISE_COLWISE
)
.
value
(
"ROWWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
ROWWISE
)
.
value
(
"COLWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
COLWISE
)
.
value
(
"ROWWISE_COLWISE"
,
transformer_engine
::
jax
::
Quantize
Layout
::
ROWWISE_COLWISE
)
.
export_values
();
}
...
...
transformer_engine/jax/csrc/extensions/quantization.cpp
View file @
ab3e5a92
...
...
@@ -13,7 +13,9 @@ namespace transformer_engine {
namespace
jax
{
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
)
{
DType
in_dtype
,
DType
out_dtype
,
JAXX_Scaling_Mode
scaling_mode
,
QuantizeLayout
q_layout
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
hidden_size
,
batch_size
};
...
...
@@ -27,10 +29,37 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
int
temp
=
0
;
auto
input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
output_shape
,
out_dtype
);
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_trans_shape
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_scaling_mode
(
scaling_mode
));
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
q_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
||
q_layout
==
QuantizeLayout
::
ROWWISE
)
{
output_tensor
.
set_rowwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_shape
);
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
}
if
(
q_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
||
q_layout
==
QuantizeLayout
::
COLWISE
)
{
auto
&
tmp_shape
=
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
tmp_shape
);
// Only the pointers will be checked for scale_inv, thus the shapes do not matter
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_columnwise_scale_inv
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
}
if
(
is_fp8_dtype
(
out_dtype
)
&&
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_amax
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_scale
(
reinterpret_cast
<
void
*>
(
&
temp
),
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
TensorWrapper
dummy_workspace
;
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
...
...
@@ -42,10 +71,10 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_
Error_Type
DBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
scale_inv_buf
,
Result_Type
trans
_scale_inv_buf
,
Result_Type
amax_
out_
buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
int64_t
scaling_mode
_enum
,
int64_t
quantize_axis_enum
,
bool
is_dbia
s
)
{
Result_Type
scale_inv_buf
,
Result_Type
colwise
_scale_inv_buf
,
Result_Type
amax_buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
JAXX_Scaling_Mode
scaling_mode
,
int64_t
quantize_layout
_enum
,
bool
is_dbias
,
int64_t
flatten_axi
s
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
...
...
@@ -54,8 +83,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
auto
*
input
=
input_buf
.
untyped_data
();
auto
scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
const
quantize_axis
=
static_cast
<
QuantizeAxis
>
(
quantize_axis_enum
);
auto
const
quantize_layout
=
static_cast
<
QuantizeLayout
>
(
quantize_layout_enum
);
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
...
...
@@ -63,9 +91,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
void
*
workspace
=
workspace_buf
->
untyped_data
();
auto
input_dims
=
input_buf
.
dimensions
();
int64_t
input_ndim
=
input_dims
.
size
();
if
(
flatten_axis
<
0
)
flatten_axis
+=
input_ndim
;
NVTE_CHECK
(
flatten_axis
<
input_ndim
&&
flatten_axis
>
0
,
"flatten_axis is out of bounds!"
);
auto
workspace_dims
=
workspace_buf
->
dimensions
();
auto
m
=
product
(
input_dims
,
0
,
input_dims
.
size
()
-
1
);
auto
n
=
input_dims
.
back
(
);
auto
m
=
product
(
input_dims
,
0
,
flatten_axis
);
auto
n
=
product
(
input_dims
,
flatten_axis
,
input_ndim
);
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
...
...
@@ -73,39 +105,58 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
std
::
vector
<
size_t
>
workspace_shape
{
workspace_dims
.
begin
(),
workspace_dims
.
end
()};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
get_nvte_
scaling_mode
(
scaling_mode
)
);
if
(
quantize_axis
==
QuantizeAxis
::
ROWWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
if
(
quantize_layout
==
QuantizeLayout
::
ROWWISE
||
quantize_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
)
{
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
is_fp8_dtype
(
out_dtype
))
{
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax
_out
=
reinterpret_cast
<
float
*>
(
amax_
out_
buf
->
untyped_data
());
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
_out
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax_out
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax_out
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
scale_inv_buf
->
dimensions
(),
flatten_axis
,
scale_inv_buf
->
dimensions
().
size
())});
}
}
}
if
(
quantize_axis
==
QuantizeAxis
::
COLWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
if
(
quantize_layout
==
QuantizeLayout
::
COLWISE
||
quantize_layout
==
QuantizeLayout
::
ROWWISE_COLWISE
)
{
auto
&
tmp_shape
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
output_trans_shape
:
output_shape
;
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
tmp_shape
);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
colwise_scale_inv_buf
=
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
trans_scale_inv_buf
;
auto
&
tmp_buf
=
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
colwise_scale_inv_buf
;
if
(
scaling_mode
==
JAXX_Scaling_Mode
::
DELAYED_TENSOR_SCALING
)
{
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
colwise_scale_inv_buf
->
dimensions
().
back
()});
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
1
});
}
else
{
output_tensor
.
set_columnwise_scale_inv
(
tmp_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
tmp_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
tmp_buf
->
dimensions
(),
0
,
flatten_axis
),
product
(
tmp_buf
->
dimensions
(),
flatten_axis
,
tmp_buf
->
dimensions
().
size
())});
}
}
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
...
...
@@ -132,9 +183,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"q_axis"
)
.
Attr
<
bool
>
(
"is_dbias"
),
.
Attr
<
JAXX_Scaling_Mode
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"q_layout"
)
.
Attr
<
bool
>
(
"is_dbias"
)
.
Attr
<
int64_t
>
(
"flatten_axis"
),
FFI_CudaGraph_Traits
);
Error_Type
DequantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
...
...
transformer_engine/jax/dense.py
View file @
ab3e5a92
...
...
@@ -15,7 +15,11 @@ import jax
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.quantize
import
QuantizerSet
,
noop_quantizer_set
from
.quantize
import
(
QuantizerSet
,
noop_quantizer_set
,
with_sharding_constraint_by_logical_axes
,
)
def
dense
(
...
...
@@ -23,6 +27,8 @@ def dense(
kernel
:
jnp
.
ndarray
,
bias
:
jnp
.
ndarray
=
None
,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
0
,)),
input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_axes
:
Tuple
[
str
,
...]
=
None
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
):
"""Perform dense layer transformation with optional quantization.
...
...
@@ -48,12 +54,12 @@ def dense(
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
output
+=
jnp
.
reshape
(
bias
,
bias_new_shape
)
else
:
output
=
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
)
output
=
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
)
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
3
,))
def
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
):
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
3
,
4
,
5
))
def
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
):
"""Internal implementation of dense layer transformation with custom VJP.
This function implements the core dense layer transformation logic with support
...
...
@@ -64,32 +70,37 @@ def _dense(x, kernel, bias, contracting_dims, quantizer_set):
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
input_axes: Logical axes for sharding the activation input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
output
,
_
=
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
)
output
,
_
=
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
)
return
output
def
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
):
def
_dense_fwd_rule
(
x
,
kernel
,
bias
,
contracting_dims
,
input_axes
,
kernel_axes
,
quantizer_set
):
"""Forward pass rule for dense layer transformation.
Args:
x: Input tensor
kernel: Weight matrix
bias: Optional bias tensor
contracting_dims: Contracting dimensions specification
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Tuple of (output, context) for backward pass
"""
x_contracting_dims
,
k_contracting_dims
=
contracting_dims
casted_x
=
tex
.
quantize
(
x
,
quantizer_set
.
x
)
casted_kernel
=
tex
.
quantize
(
kernel
,
quantizer_set
.
kernel
)
flatten_axis_x
=
-
len
(
x_contracting_dims
)
flatten_axis_k
=
len
(
k_contracting_dims
)
-
len
(
kernel
.
shape
)
casted_x
=
tex
.
quantize
(
x
,
flatten_axis
=
flatten_axis_x
,
quantizer
=
quantizer_set
.
x
)
casted_x
=
with_sharding_constraint_by_logical_axes
(
casted_x
,
input_axes
)
casted_kernel
=
tex
.
quantize
(
kernel
,
flatten_axis
=
flatten_axis_k
,
quantizer
=
quantizer_set
.
kernel
)
casted_kernel
=
with_sharding_constraint_by_logical_axes
(
casted_kernel
,
kernel_axes
)
# GEMM NN
output
=
tex
.
gemm
(
...
...
@@ -97,6 +108,7 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
casted_kernel
.
get_colwise_tensor
(),
(
x_contracting_dims
,
k_contracting_dims
),
)
use_bias
=
bias
is
not
None
if
use_bias
:
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
...
...
@@ -109,18 +121,16 @@ def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
kernel
.
shape
,
use_bias
,
quantizer_set
,
flatten_axis_k
,
)
return
output
,
ctx
def
_dense_bwd_rule
(
contracting_dims
,
ctx
,
grad
):
# pylint: disable=unused-argument
def
_dense_bwd_rule
(
contracting_dims
,
input_axes
,
kernel_axes
,
ctx
,
grad
):
# pylint: disable=unused-argument
"""Backward pass rule for dense layer transformation.
Args:
contracting_dims: Contracting dimensions specification
ctx: Context from forward pass
grad: Gradient from upstream
Returns:
Tuple of gradients with respect to inputs
"""
...
...
@@ -133,9 +143,12 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
kernel_shape
,
use_bias
,
quantizer_set
,
flatten_axis_k
,
)
=
ctx
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
quantizer
=
quantizer_set
.
dgrad
)
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
flatten_axis
=
flatten_axis_k
,
quantizer
=
quantizer_set
.
dgrad
)
# GEMM NT
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
...
...
@@ -151,6 +164,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
rowwise_casted_kernel
,
(
g_constracting_dim
,
k_constracting_dim
),
)
dgrad
=
with_sharding_constraint_by_logical_axes
(
dgrad
,
input_axes
)
# GEMM TN
# x_non_contracting_dims
...
...
@@ -161,6 +175,7 @@ def _dense_bwd_rule(contracting_dims, ctx, grad): # pylint: disable=unused-argu
wgrad
=
tex
.
gemm
(
colwise_casted_x
,
casted_grad
.
get_colwise_tensor
(),
(
x_constracting_dim
,
g_constracting_dim
)
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
return
dgrad
,
wgrad
,
dbias
,
quantizer_set
...
...
transformer_engine/jax/flax/module.py
View file @
ab3e5a92
...
...
@@ -13,7 +13,6 @@ import jax.numpy as jnp
from
flax
import
linen
as
nn
from
flax.linen
import
partitioning
as
nn_partitioning
from
jax
import
lax
from
jax
import
nn
as
jax_nn
from
jax
import
random
as
jax_random
from
jax.ad_checkpoint
import
checkpoint_name
...
...
@@ -26,8 +25,14 @@ from ..layernorm_mlp import layernorm_mlp
from
..activation
import
activation
from
..softmax
import
softmax
,
SoftmaxType
from
..sharding
import
with_sharding_constraint_by_logical_axes
from
..cpp_extensions
import
is_softmax_kernel_available
from
..cpp_extensions
import
(
is_softmax_kernel_available
,
jax_scaled_softmax
,
jax_scaled_masked_softmax
,
jax_scaled_upper_triang_masked_softmax
,
)
from
..quantize
import
QuantizerFactory
,
QuantizeConfig
,
QuantizeMeta
,
QuantizeMetaSet
,
ScalingMode
from
..sharding
import
get_non_contracting_logical_axes
PRNGKey
=
Any
Shape
=
Tuple
[
int
,
...]
...
...
@@ -167,10 +172,10 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
input_dtype
=
inputs
.
dtype
logits
=
inputs
if
self
.
softmax_type
is
not
SoftmaxType
.
SCALED
and
is_softmax_kernel_available
(
# use primitives
if
is_softmax_kernel_available
(
self
.
softmax_type
,
batch
,
heads
,
q_seqlen
,
k_seqlen
,
input_dtype
):
if
bias
is
not
None
:
logits
=
logits
+
bias
.
astype
(
input_dtype
)
...
...
@@ -179,31 +184,22 @@ class Softmax(nn.Module): # pylint: disable=too-few-public-methods
mask_
=
None
outputs
=
softmax
(
logits
,
mask_
,
self
.
scale_factor
,
self
.
softmax_type
)
# use default jax based implementation
else
:
attention_bias
=
None
if
mask
is
not
None
:
attention_bias
=
lax
.
select
(
mask
>
0
,
jnp
.
full
(
mask
.
shape
,
-
1e10
),
jnp
.
full
(
mask
.
shape
,
0.0
),
)
attention_bias
=
attention_bias
.
astype
(
input_dtype
)
if
bias
is
not
None
:
attention_bias
=
_combine_biases
(
attention_bias
,
bias
)
if
attention_bias
is
not
None
:
logits
=
logits
+
attention_bias
.
astype
(
input_dtype
)
logits
=
logits
+
bias
.
astype
(
input_dtype
)
# For the case that
self.softmax
==
SoftmaxType.SCALED
_UPPER_TRIANG_MASKED
# and kernel is unavailable, then try on pure
scaled
softmax
custom calls.
if
is_
softmax_
kernel_available
(
SoftmaxType
.
SCALED
,
batch
,
heads
,
q_seqlen
,
k_seqlen
,
input_dtype
)
:
outputs
=
softmax
(
logits
,
None
,
self
.
scale_factor
,
SoftmaxType
.
SCALED
)
if
self
.
softmax
_type
is
SoftmaxType
.
SCALED
:
outputs
=
jax_
scaled
_
softmax
(
logits
,
self
.
scale_factor
)
el
if
self
.
softmax_
type
is
SoftmaxType
.
SCALED_MASKED
:
outputs
=
jax_scaled_masked_softmax
(
logits
,
mask
,
self
.
scale_factor
)
elif
self
.
softmax_type
is
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
outputs
=
jax_scaled_upper_triang_masked_
softmax
(
logits
,
self
.
scale_factor
)
else
:
outputs
=
jax_nn
.
softmax
(
logits
*
self
.
scale_factor
)
raise
ValueError
(
f
"Unsupported softmax type:
{
self
.
softmax_type
}
. softmax_type must be [SCALED,"
" SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED]"
)
assert
input_dtype
==
outputs
.
dtype
return
outputs
...
...
@@ -360,7 +356,7 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
).
value
return
QuantizeMeta
(
scale
=
scale
,
amax_history
=
amax_history
)
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
x_meta
=
generate_quantize_meta
(
"x"
)
kernel_meta
=
generate_quantize_meta
(
"kernel"
)
grad_meta
=
generate_quantize_meta
(
"grad"
)
...
...
@@ -406,6 +402,10 @@ class DenseGeneral(TransformerEngineBase):
:math:`\frac{alpha}{rank} * lora_output`. None means no scaling.
axis: Union[Iterable[int], int], default = -1
An integer tuple with axes to apply the transformation on.
input_axes: Tuple[str, ...], default = None
Indicate the logical axes of sharding constraint to the input, like
(BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert
sharding constraint.
Optimization parameters
-----------------------
...
...
@@ -429,6 +429,7 @@ class DenseGeneral(TransformerEngineBase):
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
dtype
:
DType
=
jnp
.
float32
transpose_batch_sequence
:
bool
=
False
input_axes
:
Tuple
[
str
,
...]
=
()
def
__post_init__
(
self
):
if
self
.
kernel_init
is
None
:
...
...
@@ -460,29 +461,35 @@ class DenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
inputs
.
ndim
)
kernel_shape
=
tuple
(
inputs
.
shape
[
ax
]
for
ax
in
axis
)
+
features
if
self
.
kernel_axes
:
assert
len
(
kernel_shape
)
==
len
(
self
.
kernel_axes
),
(
"Expected len(kernel_shape) to match len(kernel_axes),"
f
"got kernel_shape
{
kernel_shape
}
and kernel_axes
{
self
.
kernel_axes
}
"
)
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
if
self
.
use_bias
:
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
).
astype
(
input_dtype
)
else
:
bias
=
None
quantizer_set
=
self
.
generate_quantizer_set
()
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
y
=
dense
(
inputs
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
inputs
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
)
if
self
.
enable_low_rank_adaptation
:
...
...
@@ -491,20 +498,14 @@ class DenseGeneral(TransformerEngineBase):
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_init_shape
=
(
kernel_compute_shape
[
0
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_init_shape
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_shape
)
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"lora_a_kernel"
,
self
.
kernel_init
,
lora_a_kernel_
init_
shape
,
lora_a_kernel_shape
,
self
.
dtype
,
axes
=
lora_a_kernel_axes
,
)
lora_a_kernel
=
jnp
.
reshape
(
lora_a_kernel
,
lora_a_kernel_shape
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
...
...
@@ -527,7 +528,6 @@ class DenseGeneral(TransformerEngineBase):
y
+=
jnp
.
reshape
(
bias
,
bias_shape
)
assert
y
.
dtype
==
input_dtype
y
=
y
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
return
y
...
...
@@ -678,6 +678,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert
self
.
axis
==
-
1
,
"Only support axis = =-1 at this moment"
input_dtype
=
inputs
.
dtype
ln_output
=
None
...
...
@@ -692,10 +693,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if
self
.
enable_layernorm
:
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
features
=
inputs
.
shape
[
-
1
]
scale
,
ln_bias
=
_create_layernorm_parameters
(
self
.
layernorm_type
,
(
features
,),
...
...
@@ -731,17 +729,12 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
kernel_shape
=
tuple
(
y
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),
)
+
features
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
...
...
@@ -756,11 +749,19 @@ class LayerNormDenseGeneral(TransformerEngineBase):
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
)
else
:
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_input_axes
)
z
=
dense
(
y
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
)
z
=
dense
(
y
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_input_axes
,
kernel_axes
=
self
.
kernel_axes
,
quantizer_set
=
quantizer_set
,
)
if
self
.
enable_low_rank_adaptation
:
lora_a_kernel_shape
=
(
...
...
@@ -768,20 +769,14 @@ class LayerNormDenseGeneral(TransformerEngineBase):
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_init_shape
=
(
kernel_compute_shape
[
0
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_init_shape
)
lora_a_kernel_axes
=
(
None
,)
*
len
(
lora_a_kernel_shape
)
lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"lora_a_kernel"
,
self
.
kernel_init
,
lora_a_kernel_
init_
shape
,
lora_a_kernel_shape
,
self
.
dtype
,
axes
=
lora_a_kernel_axes
,
)
lora_a_kernel
=
jnp
.
reshape
(
lora_a_kernel
,
lora_a_kernel_shape
)
lora_a_kernel
=
lora_a_kernel
.
astype
(
input_dtype
)
lora_b_kernel_shape
=
(
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
features
[
-
1
])
...
...
@@ -803,8 +798,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if
self
.
use_bias
:
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
).
astype
(
input_dtype
)
if
bias
is
not
None
:
bias_shape
=
(
1
,)
*
(
z
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
...
...
@@ -814,7 +808,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
z
=
z
/
self
.
depth_scaling
assert
z
.
dtype
==
input_dtype
,
f
"output_dtype=
{
z
.
dtype
}
, input_dtype=
{
input_dtype
}
"
z
=
z
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
#
z = z.reshape(*inputs.shape[: self.axis], *features)
return
z
,
ln_output
# dense_output, layer_norm_output
...
...
@@ -989,6 +983,8 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
assert
self
.
axis
==
-
1
,
"Only support axis == -1 at this moment"
ffn1_quantizer_set
=
self
.
generate_quantizer_set
(
"_0"
)
ffn2_quantizer_set
=
self
.
generate_quantizer_set
(
"_1"
)
...
...
@@ -1027,7 +1023,6 @@ class LayerNormMLP(TransformerEngineBase):
)
# LayerNorm
if
self
.
enable_layernorm
:
assert
self
.
axis
==
-
1
# Only support axis == -1 at this moment
inputs
=
with_sharding_constraint_by_logical_axes
(
inputs
,
self
.
layernorm_input_axes
)
features
=
inputs
.
shape
[
-
1
]
...
...
@@ -1071,7 +1066,7 @@ class LayerNormMLP(TransformerEngineBase):
num_activations
=
len
(
normalized_acts
)
axis
=
_canonicalize_tuple
(
self
.
axis
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
kernel_1_each_shape
=
(
np
.
prod
([
y
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1_each_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1
=
nn_partitioning
.
param_with_axes
(
"wi_kernel"
,
kernel_1_init
,
...
...
@@ -1081,13 +1076,10 @@ class LayerNormMLP(TransformerEngineBase):
self
.
dtype
,
axes
=
self
.
kernel_axes_1
,
)
kernel_1_compute_shape
=
(
reduce
(
operator
.
mul
,
[
y
.
shape
[
ax
]
for
ax
in
axis
],
1
),
num_activations
*
self
.
intermediate_dim
,
)
kernel_1
=
jnp
.
reshape
(
kernel_1
,
kernel_1_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_1
=
kernel_1
.
astype
(
input_dtype
)
hidden_size
=
inputs
.
shape
[
-
1
]
hidden_size_tuple
=
_canonicalize_tuple
(
hidden_size
)
kernel_2_shape
=
(
self
.
intermediate_dim
,)
+
hidden_size_tuple
...
...
@@ -1098,26 +1090,20 @@ class LayerNormMLP(TransformerEngineBase):
self
.
dtype
,
axes
=
self
.
kernel_axes_2
,
)
kernel_2_compute_shape
=
(
self
.
intermediate_dim
,
reduce
(
operator
.
mul
,
hidden_size_tuple
,
1
),
)
kernel_2
=
jnp
.
reshape
(
kernel_2
,
kernel_2_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_2
=
kernel_2
.
astype
(
input_dtype
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
if
self
.
use_bias
:
bias_1_shape
=
num_activations
*
self
.
intermediate_dim
bias_1_shape
=
(
num_activations
,
self
.
intermediate_dim
)
bias_1
=
nn_partitioning
.
param_with_axes
(
"wi_bias"
,
self
.
bias_init
,
bias_1_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_1
,
)
bias_1
=
bias_1
.
reshape
(
kernel_1_compute_shape
[
-
1
]).
astype
(
input_dtype
)
).
astype
(
input_dtype
)
bias_2_shape
=
(
hidden_size
,)
bias_2
=
nn_partitioning
.
param_with_axes
(
...
...
@@ -1126,8 +1112,7 @@ class LayerNormMLP(TransformerEngineBase):
bias_2_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_2
,
)
bias_2
=
bias_2
.
reshape
(
kernel_2_compute_shape
[
-
1
]).
astype
(
input_dtype
)
).
astype
(
input_dtype
)
else
:
bias_1
=
None
bias_2
=
None
...
...
@@ -1136,8 +1121,6 @@ class LayerNormMLP(TransformerEngineBase):
ffn2_ckpt_name
=
"ffn2"
if
use_fused_layernorm_mlp
:
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
out
=
layernorm_mlp
(
y
,
scale
,
...
...
@@ -1150,6 +1133,8 @@ class LayerNormMLP(TransformerEngineBase):
norm_input_axes
=
self
.
layernorm_input_axes
,
dot_1_input_axes
=
self
.
dot_1_input_axes
,
dot_2_input_axes
=
self
.
dot_2_input_axes
,
kernel_1_axes
=
self
.
kernel_axes_1
,
kernel_2_axes
=
self
.
kernel_axes_2
,
ffn1_ckpt_name
=
ffn1_ckpt_name
,
ffn2_ckpt_name
=
ffn2_ckpt_name
,
activation_type
=
normalized_acts
,
...
...
@@ -1170,6 +1155,7 @@ class LayerNormMLP(TransformerEngineBase):
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_1_input_axes
,
kernel_axes
=
self
.
kernel_axes_1
,
quantizer_set
=
ffn1_quantizer_set
,
)
else
:
...
...
@@ -1178,35 +1164,31 @@ class LayerNormMLP(TransformerEngineBase):
y
,
kernel_1
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_1_input_axes
,
kernel_axes
=
self
.
kernel_axes_1
,
quantizer_set
=
ffn1_quantizer_set
,
)
dot_1_output_axes
=
(
*
get_non_contracting_logical_axes
(
y
.
ndim
,
self
.
dot_1_input_axes
,
axis
),
*
get_non_contracting_logical_axes
(
kernel_1
.
ndim
,
self
.
kernel_axes_1
,
contract_ind
),
)
x
=
with_sharding_constraint_by_logical_axes
(
x
,
dot_1_output_axes
)
if
self
.
enable_low_rank_adaptation
:
wi_lora_a_kernel_shape
=
(
kernel_1_compute_shape
[
0
],
num_activations
,
wi_lora_a_kernel_each_shape
=
(
kernel_1_each_shape
[:
len
(
axis
)],
self
.
low_rank_adaptation_dim
,
)
wi_lora_a_kernel_init_shape
=
(
kernel_1_each_shape
[
0
],
num_activations
,
self
.
low_rank_adaptation_dim
,
)
wi_lora_a_kernel_init_each_shape
=
(
kernel_1_each_shape
[
0
],
self
.
low_rank_adaptation_dim
,
)
wi_lora_a_kernel_axes
=
(
None
,)
*
len
(
wi_lora_a_kernel_init_shape
)
wi_lora_a_kernel_axes
=
(
None
,)
*
len
(
wi_lora_a_kernel_each_shape
+
1
)
wi_lora_a_kernel
=
nn_partitioning
.
param_with_axes
(
"wi_lora_a_kernel"
,
kernel_1_init
,
num_activations
,
-
1
,
wi_lora_a_kernel_
init_
each_shape
,
-
2
,
wi_lora_a_kernel_each_shape
,
self
.
dtype
,
axes
=
wi_lora_a_kernel_axes
,
)
wi_lora_a_kernel
=
jnp
.
reshape
(
wi_lora_a_kernel
,
wi_lora_a_kernel_shape
)
wi_lora_a_kernel
=
wi_lora_a_kernel
.
astype
(
input_dtype
)
wi_lora_b_kernel_shape
=
(
...
...
@@ -1227,7 +1209,7 @@ class LayerNormMLP(TransformerEngineBase):
x
+=
_apply_low_rank_adaptation
(
y
,
axis
,
num_activations
*
self
.
intermediate_dim
,
(
num_activations
,
self
.
intermediate_dim
)
,
wi_lora_a_kernel
,
wi_lora_b_kernel
,
self
.
low_rank_adaptation_alpha
,
...
...
@@ -1241,11 +1223,12 @@ class LayerNormMLP(TransformerEngineBase):
z
=
activation
(
x
,
normalized_acts
)
else
:
activations
=
[]
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
1
)
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
2
)
for
idx
,
act_fn
in
enumerate
(
normalized_acts
):
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
activations
.
append
(
x_i
)
z
=
reduce
(
operator
.
mul
,
activations
)
z
=
jnp
.
squeeze
(
z
,
axis
=-
2
)
z
=
z
.
astype
(
input_dtype
)
z
=
nn
.
Dropout
(
...
...
@@ -1259,7 +1242,12 @@ class LayerNormMLP(TransformerEngineBase):
# DenseGeneral 2
out
=
dense
(
z
,
kernel_2
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
ffn2_quantizer_set
z
,
kernel_2
,
contracting_dims
=
(
axis
,
contract_ind
),
input_axes
=
self
.
dot_2_input_axes
,
kernel_axes
=
self
.
kernel_axes_2
,
quantizer_set
=
ffn2_quantizer_set
,
)
if
self
.
enable_low_rank_adaptation
:
...
...
transformer_engine/jax/flax/transformer.py
View file @
ab3e5a92
...
...
@@ -220,11 +220,11 @@ class _UnfusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-
if
mask
is
not
None
:
mask
=
apply_swa_mask
(
mask
)
# Currently cuDNN backend only supports SWA for causal/padding_causal, follow this
if
attn_mask_type
in
[
AttnMaskType
.
CAUSAL_MASK
,
AttnMaskType
.
PADDING_CAUSAL_MASK
]:
return
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
mask
if
attn_mask_type
in
[
AttnMaskType
.
NO_MASK
,
AttnMaskType
.
PADDING_MASK
]:
if
mask
is
not
None
:
return
SoftmaxType
.
SCALED_MASKED
,
mask
if
attn_mask_type
is
AttnMaskType
.
CAUSAL_MASK
:
return
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
,
mask
if
attn_mask_type
is
AttnMaskType
.
NO_MASK
:
return
SoftmaxType
.
SCALED
,
mask
raise
ValueError
(
f
"Unsupported
{
attn_mask_type
=
}
, supported attn_mask_type="
...
...
@@ -447,6 +447,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
.. note:: THD format only supports 'padding' or 'causal_padding' mask type.
attn_mask_type mask/sequence_descriptor SWA softmax type
--------------------------------------------------------------------------------------------
no_mask None None SCALED
causal None None SCALED_UPPER_TRIANG_MASKED
causal None Yes SCALED_MASKED
padding Required Yes/No SCALED_MASKED
padding_causal Required Yes/No SCALED_MASKED
attn_bias_type: Optional[str], default = None
Type of the attention bias passed in the attention.
Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}.
...
...
transformer_engine/jax/layernorm_dense.py
View file @
ab3e5a92
...
...
@@ -33,10 +33,9 @@ def layernorm_dense(
norm_type
:
str
=
"layernorm"
,
zero_centered_gamma
:
bool
=
False
,
epsilon
:
float
=
1e-6
,
# The logic axes of sharding constraint to the layernorm input.
layernorm_input_axes
:
Tuple
[
str
,
...]
=
None
,
# The logic axes of sharding constraint to the dot input.
dot_input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_axes
:
Tuple
[
str
,
...]
=
None
,
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
)
->
jnp
.
ndarray
:
"""Apply layer normalization followed by dense layer transformation.
...
...
@@ -56,6 +55,7 @@ def layernorm_dense(
epsilon: Small constant for numerical stability in normalization
layernorm_input_axes: Logical axes for sharding the layernorm input
dot_input_axes: Logical axes for sharding the matrix multiplication input
kernel_axes: Logical axes for sharding the weight matrix
quantizer_set: Set of quantizers for different tensor types
Returns:
...
...
@@ -78,6 +78,7 @@ def layernorm_dense(
epsilon
,
layernorm_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
)
return
output
...
...
@@ -91,6 +92,7 @@ def layernorm_dense(
7
,
8
,
9
,
10
,
),
)
def
_layernorm_dense
(
...
...
@@ -104,6 +106,7 @@ def _layernorm_dense(
epsilon
:
float
,
layernorm_input_axes
:
Tuple
[
str
,
...],
dot_input_axes
:
Tuple
[
str
,
...],
kernel_axes
:
Tuple
[
str
,
...],
quantizer_set
,
):
"""Internal implementation of layernorm_dense with custom VJP.
...
...
@@ -139,6 +142,7 @@ def _layernorm_dense(
epsilon
,
layernorm_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
)
return
output
...
...
@@ -155,6 +159,7 @@ def _layernorm_dense_fwd_rule(
epsilon
,
layernorm_input_axes
,
dot_input_axes
,
kernel_axes
,
quantizer_set
,
):
"""Forward pass rule for layernorm_dense.
...
...
@@ -171,7 +176,6 @@ def _layernorm_dense_fwd_rule(
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
k_contracting_dims
=
(
0
,)
assert
x
.
shape
[
-
1
]
==
kernel
.
shape
[
0
]
assert
len
(
kernel
.
shape
)
==
2
# Otherwise need to merge dims in quantize
x
=
with_sharding_constraint_by_logical_axes
(
x
,
layernorm_input_axes
)
...
...
@@ -184,11 +188,12 @@ def _layernorm_dense_fwd_rule(
norm_type
,
quantizer_set
.
x
,
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_input_axes
)
# Kernel in (hidden_in, hidden_out...)
casted_kernel
=
tex
.
quantize
(
kernel
,
quantizer_set
.
kernel
)
casted_
ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_
ln_out
,
dot_input
_axes
)
flatten_axis
=
1
-
len
(
kernel
.
shape
)
casted_kernel
=
tex
.
quantize
(
kernel
,
flatten_axis
=
flatten_axis
,
quantizer
=
quantizer_set
.
kernel
)
casted_
kernel
=
with_sharding_constraint_by_logical_axes
(
casted_
kernel
,
kernel
_axes
)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out...)
...
...
@@ -217,6 +222,7 @@ def _layernorm_dense_fwd_rule(
k_contracting_dims
,
use_bias
,
quantizer_set
,
flatten_axis
,
)
return
output
,
ctx
...
...
@@ -228,6 +234,7 @@ def _layernorm_dense_bwd_rule(
epsilon
,
layernorm_input_axes
,
dot_input_axes
,
# pylint: disable=unused-argument
kernel_axes
,
ctx
,
grad
,
):
...
...
@@ -256,11 +263,12 @@ def _layernorm_dense_bwd_rule(
k_contracting_dims_in_fwd
,
use_bias
,
quantizer_set
,
flatten_axis
,
)
=
ctx
grad
=
with_sharding_constraint_by_logical_axes
(
grad
,
dot_input_axes
)
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
quantizer
=
quantizer_set
.
dgrad
)
casted_grad
,
dbias
=
tex
.
quantize_dbias
(
grad
,
is_dbias
=
use_bias
,
flatten_axis
=
flatten_axis
,
quantizer
=
quantizer_set
.
dgrad
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
g_constracting_dim
=
tuple
(
...
...
@@ -291,6 +299,8 @@ def _layernorm_dense_bwd_rule(
(
x_constracting_dim
,
g_constracting_dim
),
)
wgrad
=
with_sharding_constraint_by_logical_axes
(
wgrad
,
kernel_axes
)
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dgrad
,
x
,
...
...
transformer_engine/jax/layernorm_mlp.py
View file @
ab3e5a92
...
...
@@ -23,6 +23,7 @@ from jax.ad_checkpoint import checkpoint_name
from
.
import
cpp_extensions
as
tex
from
.layernorm
import
canonicalize_norm_type
from
.quantize
import
with_sharding_constraint_by_logical_axes
,
QuantizerSet
,
noop_quantizer_set
from
.sharding
import
get_non_contracting_logical_axes
def
layernorm_mlp
(
...
...
@@ -37,6 +38,8 @@ def layernorm_mlp(
norm_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_1_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_2_input_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_1_axes
:
Tuple
[
str
,
...]
=
None
,
kernel_2_axes
:
Tuple
[
str
,
...]
=
None
,
ffn1_ckpt_name
:
str
=
"ffn1"
,
ffn2_ckpt_name
:
str
=
"ffn2"
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
...
...
@@ -66,6 +69,8 @@ def layernorm_mlp(
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
kernel_1_axes: Logical axes for sharding the first weight matrix
kernel_2_axes: Logical axes for sharding the second weight matrix
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
...
...
@@ -109,6 +114,8 @@ def layernorm_mlp(
norm_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
...
...
@@ -117,7 +124,7 @@ def layernorm_mlp(
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
))
def
_layernorm_mlp
(
x
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
...
...
@@ -132,6 +139,8 @@ def _layernorm_mlp(
norm_input_axes
:
Tuple
[
str
,
...],
dot_1_input_axes
:
Tuple
[
str
,
...],
dot_2_input_axes
:
Tuple
[
str
,
...],
kernel_1_axes
:
Tuple
[
str
,
...],
kernel_2_axes
:
Tuple
[
str
,
...],
ffn1_ckpt_name
:
str
,
ffn2_ckpt_name
:
str
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]],
...
...
@@ -179,6 +188,8 @@ def _layernorm_mlp(
norm_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
...
...
@@ -201,6 +212,8 @@ def _layernorm_mlp_fwd_rule(
norm_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
...
...
@@ -220,20 +233,21 @@ def _layernorm_mlp_fwd_rule(
Returns:
Tuple of (output, context) for automatic differentiation
"""
del
kernel_2_axes
ffn1_quantizer_set
,
ffn2_quantizer_set
=
quantizer_sets
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (hidden_in, activation_len
*
intermediate)
# Kernel_1 should be in shape of (hidden_in, activation_len
,
intermediate)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert
len
(
kernel_1
.
shape
)
==
2
assert
len
(
kernel_1
.
shape
)
==
3
assert
len
(
kernel_2
.
shape
)
==
2
assert
kernel_1
.
shape
[
1
]
==
kernel_2
.
shape
[
0
]
*
len
(
activation_type
)
assert
kernel_1
.
shape
[
-
2
]
==
len
(
activation_type
)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
k_contracting_dims
=
(
0
,)
assert
x
.
shape
[
x_contracting_dims
[
0
]]
==
kernel_1
.
shape
[
k_contracting_dims
[
0
]]
assert
kernel_1
.
shape
[
1
]
==
len
(
activation_type
)
*
kernel_2
.
shape
[
0
]
use_bias_1
=
bias_1
is
not
None
use_bias_2
=
bias_1
is
not
None
...
...
@@ -249,11 +263,10 @@ def _layernorm_mlp_fwd_rule(
norm_type
,
quantizer
=
ffn1_quantizer_set
.
x
,
)
casted_kernel_1
=
tex
.
quantize
(
kernel_1
,
quantizer
=
ffn1_quantizer_set
.
kernel
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_1_input_axes
)
casted_kernel_1
=
tex
.
quantize
(
kernel_1
,
flatten_axis
=-
2
,
quantizer
=
ffn1_quantizer_set
.
kernel
)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output
=
tex
.
gemm
(
...
...
@@ -261,6 +274,13 @@ def _layernorm_mlp_fwd_rule(
casted_kernel_1
.
get_colwise_tensor
(),
(
x_contracting_dims
,
k_contracting_dims
),
)
dot_1_output_axes
=
(
*
get_non_contracting_logical_axes
(
x
.
ndim
,
dot_1_input_axes
,
x_contracting_dims
),
*
get_non_contracting_logical_axes
(
kernel_1
.
ndim
,
kernel_1_axes
,
k_contracting_dims
),
)
dot_1_output
=
with_sharding_constraint_by_logical_axes
(
dot_1_output
,
dot_1_output_axes
)
if
use_bias_1
:
bias_1_shape
=
bias_1
.
shape
bias_1_new_shape
=
(
1
,)
*
(
dot_1_output
.
ndim
-
bias_1
.
ndim
)
+
bias_1_shape
...
...
@@ -283,6 +303,12 @@ def _layernorm_mlp_fwd_rule(
(
x_contracting_dims
,
k_contracting_dims
),
)
dot_2_output_axes
=
(
*
get_non_contracting_logical_axes
(
x
.
ndim
,
dot_2_input_axes
,
x_contracting_dims
),
*
get_non_contracting_logical_axes
(
kernel_2
.
ndim
,
None
,
k_contracting_dims
),
)
dot_2_output
=
with_sharding_constraint_by_logical_axes
(
dot_2_output
,
dot_2_output_axes
)
if
use_bias_2
:
bias_2_shape
=
bias_2
.
shape
bias_2_new_shape
=
(
1
,)
*
(
dot_2_output
.
ndim
-
bias_2
.
ndim
)
+
bias_2_shape
...
...
@@ -320,8 +346,10 @@ def _layernorm_mlp_bwd_rule(
norm_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
# pylint: disable=unused-argument
ffn2_ckpt_name
,
# pylint: disable=unused-argument
kernel_1_axes
,
kernel_2_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
ctx
,
grad
,
...
...
@@ -339,6 +367,7 @@ def _layernorm_mlp_bwd_rule(
Returns:
Tuple of gradients for all input parameters
"""
del
norm_input_axes
,
ffn1_ckpt_name
,
ffn2_ckpt_name
(
x
,
mu
,
...
...
@@ -369,11 +398,11 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_con
s
tracting_dim_2
=
tuple
(
g_contracting_dim
s
_2
=
tuple
(
range
(
grad
.
ndim
-
len
(
kernel_2_shape
)
+
len
(
k_contracting_dims_in_fwd
),
grad
.
ndim
)
)
# k_non_contracting_dims
k_con
s
tracting_dim_2
=
tuple
(
k_contracting_dim
s
_2
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_2_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
)
...
...
@@ -382,12 +411,12 @@ def _layernorm_mlp_bwd_rule(
dgrad_2
=
tex
.
gemm
(
casted_grad
.
get_rowwise_tensor
(),
rowwise_casted_kernel_2
,
(
g_con
s
tracting_dim_2
,
k_con
s
tracting_dim_2
),
(
g_contracting_dim
s
_2
,
k_contracting_dim
s
_2
),
)
dgrad_2
=
with_sharding_constraint_by_logical_axes
(
dgrad_2
,
dot_2_input_axes
)
x_con
s
tracting_dim
=
g_con
s
tracting_dim
=
tuple
(
x_contracting_dim
s
=
g_contracting_dim
s
=
tuple
(
range
(
0
,
len
(
x
.
shape
)
-
len
(
x_contracting_dims_in_fwd
))
)
...
...
@@ -396,8 +425,9 @@ def _layernorm_mlp_bwd_rule(
wgrad_2
=
tex
.
gemm
(
colwise_casted_act_out
,
casted_grad
.
get_colwise_tensor
(),
(
x_con
s
tracting_dim
,
g_con
s
tracting_dim
),
(
x_contracting_dim
s
,
g_contracting_dim
s
),
)
wgrad_2
=
with_sharding_constraint_by_logical_axes
(
wgrad_2
,
kernel_2_axes
)
casted_dact_out
,
dbias_1
=
tex
.
quantize_dact_dbias
(
dgrad_2
,
...
...
@@ -408,11 +438,12 @@ def _layernorm_mlp_bwd_rule(
)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1
=
tuple
(
range
(
dgrad_2
.
ndim
-
len
(
kernel_1_shape
)
+
len
(
k_contracting_dims_in_fwd
),
dgrad_2
.
ndim
)
dact_out_ndim
=
casted_dact_out
.
get_rowwise_tensor
().
data
.
ndim
g_contracting_dims_1
=
tuple
(
range
(
dact_out_ndim
-
len
(
kernel_1_shape
)
+
len
(
k_contracting_dims_in_fwd
),
dact_out_ndim
)
)
# k_non_contracting_dims
k_con
s
tracting_dim_1
=
tuple
(
k_contracting_dim
s
_1
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_1_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
)
...
...
@@ -420,19 +451,21 @@ def _layernorm_mlp_bwd_rule(
dgrad_1
=
tex
.
gemm
(
casted_dact_out
.
get_rowwise_tensor
(),
rowwise_casted_kernel_1
,
(
g_con
s
tracting_dim_1
,
k_con
s
tracting_dim_1
),
(
g_contracting_dim
s
_1
,
k_contracting_dim
s
_1
),
)
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
norm
_input_axes
)
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
dot_1
_input_axes
)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1
=
tex
.
gemm
(
colwise_casted_ln_out
,
casted_dact_out
.
get_colwise_tensor
(),
(
x_con
s
tracting_dim
,
g_con
s
tracting_dim
),
(
x_contracting_dim
s
,
g_contracting_dim
s
),
)
wgrad_1
=
with_sharding_constraint_by_logical_axes
(
wgrad_1
,
kernel_1_axes
)
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
dgrad_1
,
x
,
...
...
transformer_engine/jax/praxis/__init__.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from
.module
import
FusedSoftmax
,
LayerNorm
from
.module
import
LayerNormLinear
,
LayerNormMLP
,
Linear
,
TransformerEngineBaseLayer
from
.transformer
import
DotProductAttention
,
MultiHeadAttention
from
.transformer
import
RelativePositionBiases
,
TransformerLayer
from
..flax.transformer
import
TransformerLayerType
transformer_engine/jax/praxis/module.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from
dataclasses
import
field
from
functools
import
partial
from
typing
import
Callable
,
Iterable
,
Sequence
,
Tuple
,
Union
from
praxis
import
pax_fiddle
from
praxis.base_layer
import
init_var
from
praxis.base_layer
import
BaseLayer
,
WeightInit
,
WeightHParams
,
WeightHParamsCollection
from
praxis.layers
import
flax_adapter
from
praxis.pytypes
import
JTensor
from
..fp8
import
FP8Helper
from
..flax.module
import
DenseGeneral
,
LayerNormDenseGeneral
from
..flax.module
import
LayerNorm
as
flax_LayerNorm
from
..flax.module
import
LayerNormMLP
as
flax_LayerNormMLP
from
..flax.module
import
Softmax
from
..softmax
import
SoftmaxType
def
_generate_ln_scale_init
(
scale_init
):
if
scale_init
is
not
None
:
return
TransformerEngineBaseLayer
.
generate_params_init
(
"scale"
,
scale_init
)
return
scale_init
class
TransformerEngineBaseLayer
(
BaseLayer
):
"""TransformerEngineBaseLayer"""
logical_axes_rules
:
Tuple
[
Tuple
,
...]
=
None
@
staticmethod
def
generate_params_init
(
name
:
str
,
initializer
:
WeightInit
):
"""generate_params_init"""
def
kernel_init
(
key
,
shape
,
dtype
):
wp
=
WeightHParams
(
shape
=
shape
,
init
=
initializer
,
dtype
=
dtype
)
return
init_var
(
wp
,
key
,
name
)
return
kernel_init
def
create_layer
(
self
,
name
,
flax_module_cls
):
"""create_layer"""
fp8_collection_map
=
{
FP8Helper
.
FP8_COLLECTION_NAME
:
[
WeightHParamsCollection
.
SKIP_LP_REGULARIZATION
,
WeightHParamsCollection
.
OVERWRITE_WITH_GRADIENT
,
WeightHParamsCollection
.
DISALLOW_BFLOAT16_CONVERSION
,
]
}
flax_module_p
=
pax_fiddle
.
Config
(
flax_adapter
.
FlaxModuleAdapter
,
module_factory_method
=
flax_module_cls
,
logical_axes_rules
=
self
.
logical_axes_rules
,
var_collection_map
=
fp8_collection_map
,
ici_mesh_shape
=
self
.
ici_mesh_shape
,
dcn_mesh_shape
=
self
.
dcn_mesh_shape
,
mesh_axis_names
=
self
.
mesh_axis_names
,
)
self
.
create_child
(
name
,
flax_module_p
.
clone
())
class
LayerNorm
(
TransformerEngineBaseLayer
):
"""LayerNorm"""
epsilon
:
float
=
1e-6
layernorm_type
:
str
=
"layernorm"
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_cls
=
partial
(
flax_LayerNorm
,
epsilon
=
self
.
epsilon
,
layernorm_type
=
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"layer_norm"
,
ln_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
layer_norm
(
x
)
class
FusedSoftmax
(
TransformerEngineBaseLayer
):
"""FusedSoftmax"""
scale_factor
:
float
=
1.0
softmax_type
:
SoftmaxType
=
SoftmaxType
.
SCALED
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
fused_softmax_cls
=
partial
(
Softmax
,
scale_factor
=
self
.
scale_factor
,
softmax_type
=
self
.
softmax_type
)
self
.
create_layer
(
"fused_softmax"
,
fused_softmax_cls
)
def
__call__
(
self
,
x
:
JTensor
,
mask
:
JTensor
=
None
,
bias
:
JTensor
=
None
)
->
JTensor
:
"""__call__"""
return
self
.
fused_softmax
(
x
,
mask
,
bias
)
class
Linear
(
TransformerEngineBaseLayer
):
"""Linear"""
out_features
:
int
=
512
kernel_axes
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
True
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
dense_general_cls
=
partial
(
DenseGeneral
,
features
=
self
.
out_features
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes
=
self
.
kernel_axes
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"linear"
,
dense_general_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
linear
(
x
)
class
LayerNormLinear
(
TransformerEngineBaseLayer
):
"""LayerNormLinear"""
out_features
:
int
=
512
enable_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
ln_bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
1.0
)
)
ln_bias_axes
:
Tuple
[
str
,
...]
=
()
kernel_axes
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
return_layernorm_output
:
bool
=
True
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
depth_scaling
:
float
=
None
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_dense_general_cls
=
partial
(
LayerNormDenseGeneral
,
features
=
self
.
out_features
,
enable_layernorm
=
self
.
enable_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
epsilon
=
self
.
epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
ln_bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
ln_bias_init
),
ln_bias_axes
=
self
.
ln_bias_axes
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes
=
self
.
kernel_axes
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes
=
self
.
bias_axes
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
return_layernorm_output
=
self
.
return_layernorm_output
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
depth_scaling
=
self
.
depth_scaling
,
)
self
.
create_layer
(
"ln_linear"
,
ln_dense_general_cls
)
def
__call__
(
self
,
x
:
JTensor
)
->
JTensor
:
"""__call__"""
return
self
.
ln_linear
(
x
)
class
LayerNormMLP
(
TransformerEngineBaseLayer
):
"""LayerNormMLP"""
intermediate_dim
:
int
=
2048
enable_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
scale_init
:
WeightInit
=
None
scale_axes
:
Tuple
[
str
,
...]
=
()
ln_bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
1.0
)
)
ln_bias_axes
:
Tuple
[
str
,
...]
=
()
kernel_axes_1
:
Tuple
[
str
,
...]
=
()
kernel_axes_2
:
Tuple
[
str
,
...]
=
()
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
bias_axes_1
:
Tuple
[
str
,
...]
=
()
bias_axes_2
:
Tuple
[
str
,
...]
=
()
enable_low_rank_adaptation
:
bool
=
False
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
return_layernorm_output
:
bool
=
True
activations
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"relu"
,)
intermediate_dropout_rate
:
float
=
0.1
intermediate_hidden_dropout_dims
:
Sequence
[
int
]
=
()
axis
:
Union
[
Iterable
[
int
],
int
]
=
-
1
transpose_batch_sequence
:
bool
=
False
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
ln_mlp_cls
=
partial
(
flax_LayerNormMLP
,
intermediate_dim
=
self
.
intermediate_dim
,
enable_layernorm
=
self
.
enable_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
epsilon
=
self
.
epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
scale_init
=
_generate_ln_scale_init
(
self
.
scale_init
),
scale_axes
=
self
.
scale_axes
,
ln_bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"ln_bias"
,
self
.
ln_bias_init
),
ln_bias_axes
=
self
.
ln_bias_axes
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
kernel_axes_1
=
self
.
kernel_axes_1
,
kernel_axes_2
=
self
.
kernel_axes_2
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
bias_axes_1
=
self
.
bias_axes_1
,
bias_axes_2
=
self
.
bias_axes_2
,
enable_low_rank_adaptation
=
self
.
enable_low_rank_adaptation
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
return_layernorm_output
=
self
.
return_layernorm_output
,
activations
=
self
.
activations
,
intermediate_dropout_rate
=
self
.
intermediate_dropout_rate
,
intermediate_hidden_dropout_dims
=
self
.
intermediate_hidden_dropout_dims
,
axis
=
self
.
axis
,
dtype
=
self
.
dtype
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
)
self
.
create_layer
(
"ln_mlp"
,
ln_mlp_cls
)
def
__call__
(
self
,
x
:
JTensor
,
deterministic
:
bool
=
False
)
->
JTensor
:
"""__call__"""
return
self
.
ln_mlp
(
x
,
deterministic
)
transformer_engine/jax/praxis/transformer.py
deleted
100644 → 0
View file @
a8d19fd9
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules related Transformer
"""
from
dataclasses
import
field
from
functools
import
partial
from
typing
import
Optional
,
Sequence
,
Tuple
import
warnings
from
praxis
import
pax_fiddle
from
praxis.base_layer
import
WeightInit
from
praxis.pytypes
import
JTensor
from
.module
import
TransformerEngineBaseLayer
from
..flax.transformer
import
TransformerLayerType
from
..flax.transformer
import
DotProductAttention
as
flax_DotProductAttention
from
..flax.transformer
import
MultiHeadAttention
as
flax_MultiHeadAttention
from
..flax.transformer
import
RelativePositionBiases
as
flax_RelativePositionBiases
from
..flax.transformer
import
TransformerLayer
as
flax_TransformerLayer
from
..attention
import
AttnBiasType
,
AttnMaskType
class
RelativePositionBiases
(
TransformerEngineBaseLayer
):
"""RelativePositionBiases"""
num_buckets
:
int
=
32
max_distance
:
int
=
128
num_attention_heads
:
int
=
64
embedding_init
:
WeightInit
=
None
embedding_axes
:
Tuple
[
str
,
...]
=
()
@
staticmethod
def
generate_embedding_init
(
init
,
num_attention_heads
,
num_buckets
):
"""generate_embedding_init"""
embedding_init
=
init
if
embedding_init
is
None
:
rb_stddev
=
(
num_attention_heads
*
num_buckets
)
**
-
0.5
embedding_init
=
WeightInit
.
Gaussian
(
rb_stddev
)
return
embedding_init
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
embedding_init
=
RelativePositionBiases
.
generate_embedding_init
(
self
.
embedding_init
,
self
.
num_attention_heads
,
self
.
num_buckets
)
rpb_cls
=
partial
(
flax_RelativePositionBiases
,
num_buckets
=
self
.
num_buckets
,
max_distance
=
self
.
max_distance
,
num_attention_heads
=
self
.
num_attention_heads
,
embedding_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"rel_embedding"
,
embedding_init
),
embedding_axes
=
self
.
embedding_axes
,
dtype
=
self
.
dtype
,
)
self
.
create_layer
(
"relative_position_bias"
,
rpb_cls
)
def
__call__
(
self
,
q_seqlen
:
JTensor
,
k_seqlen
:
JTensor
,
bidirectional
:
bool
=
True
)
->
JTensor
:
"""__call__"""
return
self
.
relative_position_bias
(
q_seqlen
,
k_seqlen
,
bidirectional
)
class
DotProductAttention
(
TransformerEngineBaseLayer
):
"""DotProductAttention"""
head_dim
:
int
=
0
num_attention_heads
:
int
=
0
num_gqa_groups
:
Optional
[
int
]
=
None
attention_dropout
:
float
=
0.0
attn_mask_type
:
AttnMaskType
=
"causal"
attn_bias_type
:
AttnBiasType
=
None
dropout_rng_name
:
str
=
"dropout"
float32_logits
:
bool
=
False
qkv_layout
:
str
=
"bshd_bshd_bshd"
scale_factor
:
Optional
[
float
]
=
None
transpose_batch_sequence
:
bool
=
True
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
assert
self
.
head_dim
>
0
,
f
"
{
self
.
head_dim
=
}
"
assert
self
.
num_attention_heads
>
0
,
f
"
{
self
.
num_attention_heads
=
}
"
dpa_cls
=
partial
(
flax_DotProductAttention
,
head_dim
=
self
.
head_dim
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
attn_mask_type
=
self
.
attn_mask_type
,
attn_bias_type
=
self
.
attn_bias_type
,
attention_dropout
=
self
.
attention_dropout
,
dtype
=
self
.
dtype
,
dropout_rng_name
=
self
.
dropout_rng_name
,
float32_logits
=
self
.
float32_logits
,
qkv_layout
=
self
.
qkv_layout
,
scale_factor
=
self
.
scale_factor
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"dot_product_attention"
,
dpa_cls
)
def
__call__
(
self
,
query
:
JTensor
,
key
:
JTensor
,
value
:
JTensor
,
mask
:
Optional
[
JTensor
]
=
None
,
bias
:
Optional
[
JTensor
]
=
None
,
*
,
deterministic
:
bool
=
False
,
)
->
JTensor
:
"""__call__"""
return
self
.
dot_product_attention
(
query
,
key
,
value
,
mask
,
bias
,
deterministic
=
deterministic
)
class
MultiHeadAttention
(
TransformerEngineBaseLayer
):
"""MultiHeadAttention"""
head_dim
:
int
=
0
num_attention_heads
:
int
=
0
num_gqa_groups
:
Optional
[
int
]
=
None
attention_dropout
:
float
=
0.0
dropout_rng_name
:
str
=
"dropout"
input_layernorm
:
bool
=
True
layernorm_type
:
str
=
"layernorm"
layernorm_epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
return_layernorm_output
:
bool
=
False
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
attn_mask_type
:
str
=
"causal"
attn_bias_type
:
Optional
[
str
]
=
None
enable_rotary_pos_emb
:
bool
=
False
rotary_pos_emb_windows
:
Tuple
[
int
,
int
]
=
(
1
,
10000
)
rotary_pos_emb_group_method
:
str
=
"consecutive"
low_rank_adaptation_scope
:
str
=
"none"
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
fuse_qkv_params
:
bool
=
True
transpose_batch_sequence
:
bool
=
True
enable_sequence_parallel
:
bool
=
False
scale_attn_logits
:
bool
=
False
scaled_query_init
:
bool
=
True
float32_logits
:
bool
=
False
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
# Deprecated parameters
num_heads
:
Optional
[
int
]
=
None
dropout_rate
:
Optional
[
float
]
=
None
output_layernorm
:
Optional
[
bool
]
=
None
apply_residual_connection_post_layernorm
:
Optional
[
bool
]
=
None
fuse_qkv
:
Optional
[
bool
]
=
None
def
__post_init__
(
self
):
# Deal with the deprecated parameters
if
self
.
num_heads
is
not
None
:
self
.
num_attention_heads
=
self
.
num_heads
warnings
.
warn
(
f
"
{
__class__
}
.num_heads is deprecated. It will be removed recently. "
f
"Please uses
{
__class__
}
.num_attention_heads as the new API."
,
DeprecationWarning
,
)
if
self
.
dropout_rate
is
not
None
:
self
.
attention_dropout
=
self
.
dropout_rate
warnings
.
warn
(
f
"
{
__class__
}
.dropout_rate is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.attention_dropout as the new API."
,
DeprecationWarning
,
)
if
self
.
apply_residual_connection_post_layernorm
is
not
None
:
warnings
.
warn
(
f
"
{
__class__
}
.apply_residual_connection_post_layernorm is deprecated. "
f
"It will be removed recently, please use
{
__class__
}
.return_layernorm_output."
,
DeprecationWarning
,
)
if
self
.
fuse_qkv
is
not
None
:
warnings
.
warn
(
f
"
{
__class__
}
.fuse_qkv is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.fuse_qkv_params as the new API."
,
DeprecationWarning
,
)
assert
self
.
output_layernorm
is
None
,
(
f
"
{
__class__
}
.output_layernorm is deprecated. It will be removed recently. "
f
"Please use
{
__class__
}
.input_layernorm for controlling whether to apply layernorm."
)
if
self
.
num_gqa_groups
is
None
:
self
.
num_gqa_groups
=
self
.
num_heads
super
().
__post_init__
()
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
assert
self
.
head_dim
>
0
,
f
"
{
self
.
head_dim
=
}
"
assert
self
.
num_attention_heads
>
0
,
f
"
{
self
.
num_attention_heads
=
}
"
mha_cls
=
partial
(
flax_MultiHeadAttention
,
dtype
=
self
.
dtype
,
head_dim
=
self
.
head_dim
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
attention_dropout
=
self
.
attention_dropout
,
dropout_rng_name
=
self
.
dropout_rng_name
,
input_layernorm
=
self
.
input_layernorm
,
layernorm_type
=
self
.
layernorm_type
,
layernorm_epsilon
=
self
.
layernorm_epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
return_layernorm_output
=
self
.
return_layernorm_output
,
kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"kernel"
,
self
.
params_init
),
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
attn_mask_type
=
self
.
attn_mask_type
,
attn_bias_type
=
self
.
attn_bias_type
,
enable_rotary_pos_emb
=
self
.
enable_rotary_pos_emb
,
rotary_pos_emb_windows
=
self
.
rotary_pos_emb_windows
,
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
low_rank_adaptation_scope
=
self
.
low_rank_adaptation_scope
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
fuse_qkv_params
=
self
.
fuse_qkv_params
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
enable_sequence_parallel
=
self
.
enable_sequence_parallel
,
scale_attn_logits
=
self
.
scale_attn_logits
,
scaled_query_init
=
self
.
scaled_query_init
,
float32_logits
=
self
.
float32_logits
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"multi_head_attn"
,
mha_cls
)
def
__call__
(
self
,
inputs_q
:
JTensor
,
inputs_kv
:
JTensor
,
mask
:
Optional
[
JTensor
]
=
None
,
bias
:
Optional
[
JTensor
]
=
None
,
*
,
decode
:
bool
=
False
,
deterministic
:
bool
=
False
,
)
->
JTensor
:
"""__call__"""
return
self
.
multi_head_attn
(
inputs_q
,
inputs_kv
,
mask
,
bias
,
decode
=
decode
,
deterministic
=
deterministic
)
class
TransformerLayer
(
TransformerEngineBaseLayer
):
"""TransformerLayer"""
hidden_size
:
int
=
512
mlp_hidden_size
:
int
=
2048
num_attention_heads
:
int
=
8
num_gqa_groups
:
Optional
[
int
]
=
None
layernorm_type
:
str
=
"layernorm"
layernorm_epsilon
:
float
=
1e-6
zero_centered_gamma
:
bool
=
False
hidden_dropout
:
float
=
0.1
hidden_dropout_dims
:
Sequence
[
int
]
=
()
attention_dropout
:
float
=
0.1
intermediate_dropout
:
float
=
0.1
intermediate_dropout_dims
:
Sequence
[
int
]
=
()
dropout_rng_name
:
str
=
"dropout"
mlp_activations
:
Sequence
[
str
]
=
(
"relu"
,)
use_bias
:
bool
=
False
bias_init
:
WeightInit
=
field
(
# pylint: disable=invalid-field-call
default_factory
=
partial
(
WeightInit
.
Constant
,
scale
=
0.0
)
)
apply_residual_connection_post_layernorm
:
bool
=
False
output_layernorm
:
bool
=
False
float32_attention_logits
:
bool
=
False
layer_type
:
TransformerLayerType
=
TransformerLayerType
.
ENCODER
self_attn_mask_type
:
str
=
"causal"
self_attn_bias_type
:
Optional
[
str
]
=
None
enable_rotary_pos_emb
:
bool
=
False
rotary_pos_emb_windows
:
Tuple
[
int
,
int
]
=
(
1
,
10000
)
rotary_pos_emb_group_method
:
str
=
"consecutive"
low_rank_adaptation_scope
:
str
=
"none"
low_rank_adaptation_dim
:
int
=
32
low_rank_adaptation_alpha
:
float
=
None
enable_relative_embedding
:
bool
=
True
relative_embedding
:
pax_fiddle
.
Config
[
RelativePositionBiases
]
=
pax_fiddle
.
template_field
(
None
)
drop_path
:
float
=
0.0
fuse_qkv_params
:
bool
=
True
transpose_batch_sequence
:
bool
=
False
enable_sequence_parallel
:
bool
=
False
scale_attn_logits
:
bool
=
False
scaled_query_init
:
bool
=
True
window_size
:
Optional
[
Tuple
[
int
,
int
]]
=
None
def
__post_init__
(
self
):
if
self
.
num_gqa_groups
is
None
:
self
.
num_gqa_groups
=
self
.
num_attention_heads
super
().
__post_init__
()
def
setup
(
self
)
->
None
:
"""setup"""
super
().
setup
()
relative_embedding_flax_module
=
None
if
self
.
enable_relative_embedding
and
self
.
relative_embedding
is
not
None
:
assert
self
.
relative_embedding
.
num_attention_heads
==
self
.
num_attention_heads
,
(
"TransformerLayer.relative_embedding.num_attention_heads shoule be"
"the same as TransformerLayer.num_attention_heads."
)
embedding_init
=
RelativePositionBiases
.
generate_embedding_init
(
self
.
relative_embedding
.
embedding_init
,
self
.
relative_embedding
.
num_attention_heads
,
self
.
relative_embedding
.
num_buckets
,
)
relative_embedding_flax_module
=
flax_RelativePositionBiases
(
num_buckets
=
self
.
relative_embedding
.
num_buckets
,
max_distance
=
self
.
relative_embedding
.
max_distance
,
num_attention_heads
=
self
.
relative_embedding
.
num_attention_heads
,
embedding_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"rel_embedding"
,
embedding_init
),
embedding_axes
=
self
.
relative_embedding
.
embedding_axes
,
dtype
=
self
.
relative_embedding
.
dtype
,
)
transformerlayer_cls
=
partial
(
flax_TransformerLayer
,
dtype
=
self
.
dtype
,
hidden_size
=
self
.
hidden_size
,
mlp_hidden_size
=
self
.
mlp_hidden_size
,
num_attention_heads
=
self
.
num_attention_heads
,
num_gqa_groups
=
self
.
num_gqa_groups
,
layernorm_type
=
self
.
layernorm_type
,
layernorm_epsilon
=
self
.
layernorm_epsilon
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
hidden_dropout
=
self
.
hidden_dropout
,
hidden_dropout_dims
=
self
.
hidden_dropout_dims
,
attention_dropout
=
self
.
attention_dropout
,
intermediate_dropout
=
self
.
intermediate_dropout
,
intermediate_dropout_dims
=
self
.
intermediate_dropout_dims
,
dropout_rng_name
=
self
.
dropout_rng_name
,
mha_kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"mha_kernel"
,
self
.
params_init
),
mlp_kernel_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"mlp_kernel"
,
self
.
params_init
),
mlp_activations
=
self
.
mlp_activations
,
use_bias
=
self
.
use_bias
,
bias_init
=
TransformerEngineBaseLayer
.
generate_params_init
(
"bias"
,
self
.
bias_init
),
apply_residual_connection_post_layernorm
=
self
.
apply_residual_connection_post_layernorm
,
output_layernorm
=
self
.
output_layernorm
,
float32_attention_logits
=
self
.
float32_attention_logits
,
layer_type
=
self
.
layer_type
,
self_attn_mask_type
=
self
.
self_attn_mask_type
,
self_attn_bias_type
=
self
.
self_attn_bias_type
,
enable_rotary_pos_emb
=
self
.
enable_rotary_pos_emb
,
rotary_pos_emb_windows
=
self
.
rotary_pos_emb_windows
,
rotary_pos_emb_group_method
=
self
.
rotary_pos_emb_group_method
,
low_rank_adaptation_scope
=
self
.
low_rank_adaptation_scope
,
low_rank_adaptation_dim
=
self
.
low_rank_adaptation_dim
,
low_rank_adaptation_alpha
=
self
.
low_rank_adaptation_alpha
,
enable_relative_embedding
=
self
.
enable_relative_embedding
,
relative_embedding
=
relative_embedding_flax_module
,
drop_path
=
self
.
drop_path
,
fuse_qkv_params
=
self
.
fuse_qkv_params
,
transpose_batch_sequence
=
self
.
transpose_batch_sequence
,
enable_sequence_parallel
=
self
.
enable_sequence_parallel
,
scale_attn_logits
=
self
.
scale_attn_logits
,
scaled_query_init
=
self
.
scaled_query_init
,
window_size
=
self
.
window_size
,
)
self
.
create_layer
(
"transformerlayer"
,
transformerlayer_cls
)
def
__call__
(
self
,
inputs
:
JTensor
,
encoded
:
JTensor
=
None
,
attention_mask
:
JTensor
=
None
,
encoder_decoder_mask
:
JTensor
=
None
,
deterministic
:
bool
=
False
,
decode
:
bool
=
False
,
max_decode_length
:
bool
=
None
,
)
->
JTensor
:
"""__call__"""
return
self
.
transformerlayer
(
inputs
,
encoded
,
attention_mask
,
encoder_decoder_mask
,
deterministic
,
decode
,
max_decode_length
,
)
transformer_engine/jax/quantize/dequantizer.py
View file @
ab3e5a92
...
...
@@ -57,26 +57,35 @@ class Dequantizer:
data
=
scaled_tensor
.
data
.
astype
(
jnp
.
float32
)
data_shape
=
data
.
shape
scale
=
scaled_tensor
.
scale_inv
.
view
(
jnp
.
uint8
).
astype
(
jnp
.
float32
)
flatten_axis
=
scaled_tensor
.
flatten_axis
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
if
flatten_axis
<
0
else
flatten_axis
assert
(
0
<
flatten_axis
<
len
(
data_shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
scale_shape
=
scaled_tensor
.
scaling_mode
.
get_scale_shape
(
scaled_tensor
.
data
.
shape
,
scaled_tensor
.
is_colwise
,
is_padded
=
False
data
_
shape
,
scaled_tensor
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
scale
=
jax
.
lax
.
slice
(
scale
,
[
0
]
*
len
(
scale_shape
),
scale_shape
)
# slice out the padding
data
=
data
.
reshape
(
*
data_shape
[:
-
2
],
scale_shape
[
-
2
],
int
(
data_shape
[
-
2
]
/
scale_shape
[
-
2
]),
*
data_shape
[:
flatten_axis
-
1
],
scale_shape
[
flatten_axis
-
1
],
int
(
data_shape
[
flatten_axis
-
1
]
/
scale_shape
[
flatten_axis
-
1
]),
*
data_shape
[
flatten_axis
:
-
1
],
scale_shape
[
-
1
],
int
(
data_shape
[
-
1
]
/
scale_shape
[
-
1
]),
)
scale
=
jnp
.
expand_dims
(
scale
,
axis
=
(
-
1
,
-
3
))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
scale
=
jnp
.
expand_dims
(
scale
,
axis
=
(
flatten_axis
+
2
-
2
,
-
1
))
# E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
return
jnp
.
asarray
(
data
*
jnp
.
power
(
2
,
scale
-
127
),
scaled_tensor
.
dq_dtype
).
reshape
(
data_shape
)
funcs
=
{
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
_dq_func_block_scaling
,
ScalingMode
.
DELAYED_TENSOR_SCALING
:
_dq_func_tensor_scaling
,
ScalingMode
.
MXFP8_1D_SCALING
:
_dq_func_block_scaling
,
}
@
staticmethod
...
...
transformer_engine/jax/quantize/helper.py
View file @
ab3e5a92
...
...
@@ -27,7 +27,14 @@ from transformer_engine.jax.sharding import global_shard_guard, MeshResource
from
.scaling_modes
import
ScalingMode
from
..
import
cpp_extensions
as
tex
__all__
=
[
"QuantizeConfig"
,
"fp8_autocast"
,
"is_fp8_available"
,
"update_collections"
]
__all__
=
[
"QuantizeConfig"
,
"fp8_autocast"
,
"is_fp8_available"
,
"update_collections"
,
"get_delayed_scaling"
,
"NVTE_FP8_COLLECTION_NAME"
,
]
_is_fp8_available
=
None
_reason_for_no_fp8
=
""
...
...
@@ -87,15 +94,15 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
A tuple of (bool, str) indicating support and any error message
"""
gpu_arch
=
get_device_compute_capability
(
gpu_id
)
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
return
_check_delayed_scaling_fp8_support
(
gpu_arch
)
if
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
if
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
return
_check_block_scaling_fp8_support
(
gpu_arch
)
return
(
False
,
"Unsupported scaling_mode!"
)
def
is_fp8_available
(
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
gpu_id
=
None
,
)
->
Tuple
[
bool
,
str
]:
"""Check if FP8 is available for the given scaling mode and GPU.
...
...
@@ -172,37 +179,12 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
ValueError: If the recipe type is not supported
"""
if
isinstance
(
fp8_recipe
,
recipe
.
DelayedScaling
):
return
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
return
ScalingMode
.
DELAYED_TENSOR_SCALING
if
isinstance
(
fp8_recipe
,
recipe
.
MXFP8BlockScaling
):
return
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
return
ScalingMode
.
MXFP8_1D_SCALING
raise
ValueError
(
"Invalid fp8_recipe!"
)
def
update_collections
(
new
:
Collection
,
original
:
Collection
)
->
Collection
:
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert
isinstance
(
original
,
(
dict
,
FrozenDict
))
assert
isinstance
(
new
,
(
dict
,
FrozenDict
))
frozen_original
=
FrozenDict
(
original
)
if
not
isinstance
(
original
,
FrozenDict
)
else
original
for
key
in
new
:
if
key
in
frozen_original
:
frozen_original
,
_
=
frozen_original
.
pop
(
key
)
new_coll
=
FrozenDict
({
**
new
,
**
frozen_original
})
if
not
isinstance
(
original
,
FrozenDict
):
new_coll
=
new_coll
.
unfreeze
()
return
new_coll
class
QuantizeConfig
:
"""Configuration class for quantization settings.
...
...
@@ -227,7 +209,7 @@ class QuantizeConfig:
INITIALIZED
=
False
MARGIN
:
float
=
0.0
COLLECTION_NAME
:
str
=
"
quantize
_meta"
COLLECTION_NAME
:
str
=
"
fp8
_meta
s
"
FP8_FORMAT
:
recipe
.
Format
=
recipe
.
Format
.
HYBRID
FWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
0
]
BWD_DTYPE
:
DType
=
_format2dtypes
(
recipe
.
Format
.
HYBRID
)[
1
]
...
...
@@ -235,7 +217,7 @@ class QuantizeConfig:
FP8_2X_ACC_DGRAD
:
bool
=
False
FP8_2X_ACC_WGRAD
:
bool
=
False
IF_QUANTIZE_2X
:
bool
=
False
SCALING_MODE
:
ScalingMode
=
ScalingMode
.
NVTE_
NO_SCALING
SCALING_MODE
:
ScalingMode
=
ScalingMode
.
NO_SCALING
# DelayedScaling
AMAX_HISTORY_LEN
:
int
=
1024
...
...
@@ -271,11 +253,11 @@ class QuantizeConfig:
cls
.
MARGIN
=
0.0
cls
.
FP8_FORMAT
=
recipe
.
Format
.
HYBRID
cls
.
FWD_DTYPE
,
cls
.
BWD_DTYPE
=
_format2dtypes
(
cls
.
FP8_FORMAT
)
cls
.
SCALING_MODE
=
ScalingMode
.
NVTE_
NO_SCALING
cls
.
SCALING_MODE
=
ScalingMode
.
NO_SCALING
cls
.
FP8_2X_ACC_FPROP
=
False
cls
.
FP8_2X_ACC_DGRAD
=
False
cls
.
FP8_2X_ACC_WGRAD
=
False
cls
.
SCALING_MODE
=
ScalingMode
.
NVTE_
NO_SCALING
cls
.
SCALING_MODE
=
ScalingMode
.
NO_SCALING
cls
.
IF_QUANTIZE_2X
=
False
# DelayedScaling
cls
.
AMAX_HISTORY_LEN
=
1024
...
...
@@ -414,3 +396,56 @@ def fp8_autocast(
yield
finally
:
Config
.
finalize
()
def
get_delayed_scaling
():
r
"""
Obtain an instance of DelayedScaling which is set via fp8_autocast.
.. note::
We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
, and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
recipe.DelayedScaling would be returned as the default values.
Returns
-------
delay_scaling : DelayedScaling
an instance of DelayedScaling which is set via fp8_autocast.
"""
amax_compute_algo
=
(
"max"
if
QuantizeConfig
.
AMAX_COMPUTE_ALGO
is
AmaxComputeAlgo
.
MAX
else
"most_recent"
)
return
recipe
.
DelayedScaling
(
margin
=
int
(
QuantizeConfig
.
MARGIN
),
fp8_format
=
QuantizeConfig
.
FP8_FORMAT
,
amax_history_len
=
QuantizeConfig
.
AMAX_HISTORY_LEN
,
amax_compute_algo
=
amax_compute_algo
,
)
def
update_collections
(
new
:
Collection
,
original
:
Collection
)
->
Collection
:
r
"""Update collections with new values while preserving original structure.
Args:
new: New collection of values to add/update
original: Original collection to update
Returns:
Updated collection with new values merged with original
Raises:
AssertionError: If either collection is not a dict or FrozenDict
"""
assert
isinstance
(
original
,
(
dict
,
FrozenDict
))
assert
isinstance
(
new
,
(
dict
,
FrozenDict
))
frozen_original
=
FrozenDict
(
original
)
if
not
isinstance
(
original
,
FrozenDict
)
else
original
for
key
in
new
:
if
key
in
frozen_original
:
frozen_original
,
_
=
frozen_original
.
pop
(
key
)
new_coll
=
FrozenDict
({
**
new
,
**
frozen_original
})
if
not
isinstance
(
original
,
FrozenDict
):
new_coll
=
new_coll
.
unfreeze
()
return
new_coll
NVTE_FP8_COLLECTION_NAME
=
QuantizeConfig
.
COLLECTION_NAME
transformer_engine/jax/quantize/quantizer.py
View file @
ab3e5a92
...
...
@@ -14,7 +14,7 @@ from typing import Union, Optional
import
jax
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
Quantize
Axis
from
transformer_engine_jax
import
Quantize
Layout
from
.scaling_modes
import
ScalingMode
from
.tensor
import
ScaledTensor1x
,
ScaledTensor2x
,
ScaledTensorFactory
...
...
@@ -24,7 +24,7 @@ from .helper import (
)
__all__
=
[
"Quantize
Axis
"
,
"Quantize
Layout
"
,
"Quantizer"
,
"QuantizerSet"
,
"DelayedScaleQuantizer"
,
...
...
@@ -45,12 +45,12 @@ class Quantizer(ABC):
Attributes:
q_dtype: The data type for quantized values
scaling_mode: The scaling mode to use for quantization
q_
axis
: The quantization axis (row-wise, column-wise, or both)
q_
layout
: The quantization axis (row-wise, column-wise, or both)
"""
q_dtype
:
jnp
.
dtype
scaling_mode
:
ScalingMode
q_
axis
:
Quantize
Axis
q_
layout
:
Quantize
Layout
def
tree_flatten
(
self
):
"""Flatten the quantizer for JAX tree operations.
...
...
@@ -59,7 +59,7 @@ class Quantizer(ABC):
Tuple of (children, aux_data) for tree operations
"""
children
=
()
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
axis
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
layout
)
return
(
children
,
aux_data
)
@
classmethod
...
...
@@ -85,30 +85,31 @@ class Quantizer(ABC):
Returns:
True if using both row-wise and column-wise quantization
"""
return
self
.
q_
axis
==
Quantize
Axis
.
ROWWISE_COLWISE
return
self
.
q_
layout
==
Quantize
Layout
.
ROWWISE_COLWISE
@
abstractmethod
def
get_layout
(
self
)
->
str
:
"""Get the data layout.
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data
data_
layout.
Returns:
Data layout in string format
Data
data_
layout in string format
"""
@
abstractmethod
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Core quantization function to be implemented by subclasses.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values, default is x.dtype
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
):
def
quantize
(
self
,
x
,
is_rowwise
=
False
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
):
"""Quantize a tensor using the internal _quantize_func().
Args:
...
...
@@ -116,21 +117,26 @@ class Quantizer(ABC):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
if
(
is_rowwise
and
is_colwise
)
or
self
.
is_2x2x
():
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
colwise_tensor
=
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
colwise_tensor
=
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
if
is_colwise
:
return
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
)
return
self
.
_quantize_func
(
x
,
is_colwise
=
True
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
return
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
):
def
get_scale_shapes
(
self
,
data_shape
,
is_padded
=
True
,
flatten_axis
=-
1
):
"""Get shapes for scale tensors.
Args:
...
...
@@ -140,7 +146,7 @@ class Quantizer(ABC):
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
)
return
self
.
scaling_mode
.
get_scale_shape_2x
(
data_shape
,
is_padded
,
flatten_axis
)
def
get_scale_dtype
(
self
):
"""Get the data type for scale tensors.
...
...
@@ -161,13 +167,13 @@ class DelayedScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
q_
axis
: Quantization axis (default: ROWWISE_COLWISE)
q_
layout
: Quantization axis (default: ROWWISE_COLWISE)
scale: Current scaling factor
amax_history: History of maximum absolute values
"""
scaling_mode
:
ScalingMode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
q_
axis
:
Quantize
Axis
=
Quantize
Axis
.
ROWWISE_COLWISE
scaling_mode
:
ScalingMode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
q_
layout
:
Quantize
Layout
=
Quantize
Layout
.
ROWWISE_COLWISE
scale
:
jnp
.
ndarray
=
field
(
default_factory
=
lambda
:
jnp
.
ones
((
1
,),
jnp
.
float32
))
amax_history
:
jnp
.
ndarray
=
field
(
...
...
@@ -181,35 +187,37 @@ class DelayedScaleQuantizer(Quantizer):
Tuple of (children, aux_data) for tree operations
"""
children
=
(
self
.
scale
,
self
.
amax_history
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
axis
)
aux_data
=
(
self
.
q_dtype
,
self
.
scaling_mode
,
self
.
q_
layout
)
return
(
children
,
aux_data
)
def
get_layout
(
self
)
->
str
:
"""Get the data layout string.
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data
data_
layout string.
Returns:
Data layout in string format
Data
data_
layout in string format
Raises:
ValueError: If quantization axis is invalid
"""
layout
=
"NT"
if
self
.
q_axis
==
QuantizeAxis
.
ROWWISE_COLWISE
:
return
layout
if
self
.
q_axis
==
QuantizeAxis
.
ROWWISE
:
return
layout
[
0
]
if
self
.
q_axis
==
QuantizeAxis
.
COLWISE
:
return
layout
[
1
]
raise
ValueError
(
f
"Invalid q_axis:
{
self
.
q_axis
}
"
)
def
_quantize_func
(
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
data_layout
=
"NT"
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE_COLWISE
:
return
data_layout
if
self
.
q_layout
==
QuantizeLayout
.
ROWWISE
:
return
data_layout
[
0
]
if
self
.
q_layout
==
QuantizeLayout
.
COLWISE
:
return
data_layout
[
1
]
raise
ValueError
(
f
"Invalid q_layout:
{
self
.
q_layout
}
"
)
def
_quantize_func
(
self
,
x
:
jnp
.
ndarray
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Quantize function helper for delayed scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
...
...
@@ -232,9 +240,12 @@ class DelayedScaleQuantizer(Quantizer):
scale_inv
=
scale_inv
,
scaling_mode
=
self
.
scaling_mode
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
def
quantize
(
self
,
x
,
is_rowwise
:
bool
=
None
,
is_colwise
:
bool
=
None
,
dq_dtype
=
None
):
def
quantize
(
self
,
x
,
is_rowwise
:
bool
=
None
,
is_colwise
:
bool
=
None
,
dq_dtype
=
None
,
flatten_axis
=-
1
):
"""Quantize a tensor using the internal _quantize_func().
Args:
...
...
@@ -242,32 +253,40 @@ class DelayedScaleQuantizer(Quantizer):
is_rowwise: Whether to use row-wise quantization
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x or ScaledTensor2x containing the quantized data
"""
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
if
flatten_axis
<
0
:
flatten_axis
+=
x
.
ndim
assert
0
<
flatten_axis
<
x
.
ndim
,
"flatten_axis is out of bounds!"
is_rowwise
=
(
is_rowwise
if
is_rowwise
is
not
None
else
(
self
.
q_
axis
==
Quantize
Axis
.
ROWWISE
or
self
.
is_2x2x
())
else
(
self
.
q_
layout
==
Quantize
Layout
.
ROWWISE
or
self
.
is_2x2x
())
)
is_colwise
=
(
is_colwise
if
is_colwise
is
not
None
else
(
self
.
q_
axis
==
Quantize
Axis
.
COLWISE
or
self
.
is_2x2x
())
else
(
self
.
q_
layout
==
Quantize
Layout
.
COLWISE
or
self
.
is_2x2x
())
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
)
rowwise_tensor
=
self
.
_quantize_func
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
colwise_tensor
=
None
if
is_colwise
:
colwise_tensor
=
ScaledTensorFactory
.
create_1x
(
data
=
jnp
.
transpose
(
rowwise_tensor
.
data
,
(
-
1
,
*
range
(
rowwise_tensor
.
data
.
ndim
-
1
))),
data
=
jnp
.
transpose
(
rowwise_tensor
.
data
,
(
*
range
(
flatten_axis
,
x
.
ndim
),
*
range
(
flatten_axis
))
),
scale_inv
=
rowwise_tensor
.
scale_inv
,
scaling_mode
=
self
.
scaling_mode
,
dq_dtype
=
dq_dtype
,
is_colwise
=
True
,
layout
=
"T"
,
data_layout
=
"T"
,
flatten_axis
=
flatten_axis
,
)
if
is_colwise
and
is_rowwise
:
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
...
...
@@ -353,46 +372,56 @@ class BlockScaleQuantizer(Quantizer):
Attributes:
scaling_mode: Set to NVTE_MXFP8_1D_SCALING
q_
axis
: Quantization axis (default: ROWWISE_COLWISE)
q_
layout
: Quantization axis (default: ROWWISE_COLWISE)
"""
scaling_mode
:
ScalingMode
=
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
q_
axis
:
Quantize
Axis
=
Quantize
Axis
.
ROWWISE_COLWISE
scaling_mode
:
ScalingMode
=
ScalingMode
.
MXFP8_1D_SCALING
q_
layout
:
Quantize
Layout
=
Quantize
Layout
.
ROWWISE_COLWISE
def
get_layout
(
self
)
->
str
:
"""Get the data layout string.
def
get_
data_
layout
(
self
)
->
str
:
"""Get the data
data_
layout string.
Returns:
Data layout in string format
Data
data_
layout in string format
"""
if
self
.
is_2x2x
():
return
"NN"
return
"N"
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
)
->
ScaledTensor1x
:
def
_quantize_func
(
self
,
x
,
is_colwise
=
False
,
dq_dtype
=
None
,
flatten_axis
=-
1
)
->
ScaledTensor1x
:
"""Quantize function helper for block scaling FP8.
Args:
x: Input tensor to quantize
is_colwise: Whether to use column-wise quantization
dq_dtype: Data type for dequantized values
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x containing the quantized data
"""
# TODO(Phuong): use quantize_func from JAX
if
flatten_axis
<
0
:
flatten_axis
=
x
.
ndim
+
flatten_axis
assert
(
0
<=
flatten_axis
<
x
.
ndim
),
f
"Invalid flatten_axis:
{
flatten_axis
}
for tensor of shape
{
x
.
shape
}
"
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
x_shape
=
x
.
shape
scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
x_shape
,
is_colwise
,
is_padded
=
False
)
scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
x_shape
,
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
scale_dtype
=
self
.
scaling_mode
.
get_scale_dtype
()
x
=
x
.
reshape
(
*
x_shape
[:
-
2
],
scale_shape
[
-
2
],
int
(
x_shape
[
-
2
]
/
scale_shape
[
-
2
]),
*
x_shape
[:
flatten_axis
-
1
],
scale_shape
[
flatten_axis
-
1
],
int
(
x_shape
[
flatten_axis
-
1
]
/
scale_shape
[
flatten_axis
-
1
]),
*
x_shape
[
flatten_axis
:
-
1
],
scale_shape
[
-
1
],
int
(
x_shape
[
-
1
]
/
scale_shape
[
-
1
]),
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
),
axis
=
(
-
3
,
-
1
),
keepdims
=
True
)
amax
=
jnp
.
max
(
jnp
.
abs
(
x
),
axis
=
(
flatten_axis
+
2
-
2
,
-
1
),
keepdims
=
True
)
MAX
=
jnp
.
finfo
(
self
.
q_dtype
).
max
.
astype
(
jnp
.
float32
)
scales
=
amax
.
astype
(
jnp
.
float32
)
/
MAX
...
...
@@ -409,6 +438,7 @@ class BlockScaleQuantizer(Quantizer):
self
.
scaling_mode
,
is_colwise
=
is_colwise
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
def
_cast_to_e8m0_with_rounding_up
(
self
,
scales
):
...
...
@@ -500,8 +530,8 @@ class QuantizerFactory:
"""
quantizer_type_map
=
{
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
DelayedScaleQuantizer
,
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
BlockScaleQuantizer
,
ScalingMode
.
DELAYED_TENSOR_SCALING
:
DelayedScaleQuantizer
,
ScalingMode
.
MXFP8_1D_SCALING
:
BlockScaleQuantizer
,
}
@
staticmethod
...
...
@@ -509,7 +539,7 @@ class QuantizerFactory:
n_quantizers
:
int
=
1
,
scaling_mode
:
ScalingMode
=
None
,
q_dtype
:
jnp
.
dtype
=
None
,
q_
axis
:
Quantize
Axis
=
None
,
q_
layout
:
Quantize
Layout
=
None
,
**
kwargs
,
)
->
Quantizer
:
"""Create one or more quantizers with specified parameters.
...
...
@@ -518,15 +548,17 @@ class QuantizerFactory:
n_quantizers: Number of quantizers to create
scaling_mode: Scaling mode to use
q_dtype: Quantization data type
q_axis: Quantization axis
q_layout: Quantization axis
flatten_axis: The quantization axis for the tensor
**kwargs: Additional arguments for quantizer initialization
Returns:
A single quantizer or tuple of quantizers
"""
# (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted
# assert scaling_mode != ScalingMode.NVTE_INVALID_SCALING
if
scaling_mode
in
(
ScalingMode
.
NVTE_NO_SCALING
,
ScalingMode
.
NVTE_INVALID_SCALING
):
assert
isinstance
(
scaling_mode
,
ScalingMode
),
"Invalid scaling_mode type"
# import pdb; pdb.set_trace()
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
quantizers
=
[
None
]
*
n_quantizers
else
:
quantizers
=
[]
...
...
@@ -534,7 +566,7 @@ class QuantizerFactory:
quantizer_type
=
QuantizerFactory
.
quantizer_type_map
.
get
(
scaling_mode
)
quantizers
.
append
(
quantizer_type
(
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
axis
=
q_axis
,
**
kwargs
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
layout
=
q_layout
,
**
kwargs
)
)
return
quantizers
[
0
]
if
len
(
quantizers
)
==
1
else
tuple
(
quantizers
)
...
...
@@ -554,11 +586,11 @@ class QuantizerFactory:
A QuantizerSet instance
"""
if
is_2x2x
:
q_
axis
_x
=
q_
axis
_kernel
=
q_
axis
_dgrad
=
Quantize
Axis
.
ROWWISE_COLWISE
q_
layout
_x
=
q_
layout
_kernel
=
q_
layout
_dgrad
=
Quantize
Layout
.
ROWWISE_COLWISE
else
:
q_
axis
_x
=
Quantize
Axis
.
ROWWISE
q_
axis
_kernel
=
Quantize
Axis
.
COLWISE
q_
axis
_dgrad
=
None
q_
layout
_x
=
Quantize
Layout
.
ROWWISE
q_
layout
_kernel
=
Quantize
Layout
.
COLWISE
q_
layout
_dgrad
=
None
if
"quantize_meta_set"
in
kwargs
:
quantize_meta_set
=
kwargs
.
get
(
"quantize_meta_set"
)
...
...
@@ -577,9 +609,11 @@ class QuantizerFactory:
else
:
args_x
=
args_kernel
=
args_grad
=
{}
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_x
,
**
args_x
)
q_kernel
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_axis_kernel
,
**
args_kernel
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_axis_dgrad
,
**
args_grad
)
q_x
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_x
,
**
args_x
)
q_kernel
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
fwd_dtype
,
q_layout_kernel
,
**
args_kernel
)
q_dgrad
=
QuantizerFactory
.
create
(
1
,
scaling_mode
,
bwd_dtype
,
q_layout_dgrad
,
**
args_grad
)
return
QuantizerSet
(
x
=
q_x
,
kernel
=
q_kernel
,
dgrad
=
q_dgrad
)
@
staticmethod
...
...
@@ -618,4 +652,4 @@ class QuantizerFactory:
return
q_set
[
0
]
if
len
(
q_set
)
==
1
else
tuple
(
q_set
)
noop_quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
ScalingMode
.
NVTE_
NO_SCALING
)
noop_quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
ScalingMode
.
NO_SCALING
)
transformer_engine/jax/quantize/scaling_modes.py
View file @
ab3e5a92
...
...
@@ -16,11 +16,33 @@ from typing import Tuple, Dict
from
functools
import
reduce
import
operator
from
jax.experimental.custom_partitioning
import
CompoundFactor
from
jax.tree_util
import
register_pytree_node_class
import
jax.numpy
as
jnp
from
transformer_engine_jax
import
JAXX_Scaling_Mode
__all__
=
[
"ScalingMode"
]
__all__
=
[
"QuantizeShardyRules"
,
"ScalingMode"
]
@
dataclass
class
QuantizeShardyRules
:
"""Information necessary to shard scale tensors with Shardy.
Attributes:
input_spec: Specification for the input axes
rowwise_rule: Sharding rule for the row-wise scale tensor, depends on
the axes in `input_spec`
colwise_rule: Likewise for the column-wise scale tensor.
factor_sizes: For block scaling, contains the block size factor, which is
used in `input_spec`.
"""
input_spec
:
Tuple
[
str
]
rowwise_rule
:
Tuple
[
str
]
colwise_rule
:
Tuple
[
str
]
factor_sizes
:
Dict
[
str
,
int
]
class
ScalingModeMetadataImpl
(
ABC
):
...
...
@@ -40,7 +62,11 @@ class ScalingModeMetadataImpl(ABC):
@
abstractmethod
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors.
...
...
@@ -48,11 +74,26 @@ class ScalingModeMetadataImpl(ABC):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
@
abstractmethod
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
class
DelayedScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
"""Implementation for delayed scaling mode.
...
...
@@ -69,7 +110,11 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
return
jnp
.
float32
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors in delayed scaling.
...
...
@@ -77,6 +122,7 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being scaled
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors - (1,)
...
...
@@ -84,6 +130,23 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
del
data_shape
,
is_colwise
return
(
1
,)
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
flatten_axis: Axis along which data can be flattened to 2D for quantization.
Returns:
The Shardy rules for the scaling mode
"""
del
flatten_axis
input_spec
=
tuple
(
f
"x
{
i
}
"
for
i
in
range
(
input_rank
))
return
QuantizeShardyRules
(
input_spec
,
(
unique_var
,),
(
unique_var
,),
{})
class
BlockScalingModeMetadataImpl
(
ScalingModeMetadataImpl
):
"""Implementation for block scaling mode.
...
...
@@ -113,8 +176,35 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
"""
return
jnp
.
float8_e8m0fnu
def
_apply_scale_shape_correction
(
self
,
data_shape
,
n_scale_blocks
,
scale_block_dim
):
"""Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
if
len
(
data_shape
)
>
1
:
# handle last dim
assert
data_shape
[
-
1
]
%
scale_block_dim
==
0
last
=
data_shape
[
-
1
]
//
scale_block_dim
scale_shape
=
(
last
,)
assert
n_scale_blocks
%
last
==
0
n_scale_blocks
//=
last
# handle middle dim, exclude first and last
for
mid
in
reversed
(
data_shape
[
1
:
-
1
]):
scale_shape
=
(
mid
,)
+
scale_shape
assert
n_scale_blocks
%
mid
==
0
n_scale_blocks
//=
mid
scale_shape
=
(
n_scale_blocks
,)
+
scale_shape
else
:
scale_shape
=
(
n_scale_blocks
,)
assert
len
(
scale_shape
)
==
len
(
data_shape
),
f
"scale_shape
{
scale_shape
}
, data_shape
{
data_shape
}
"
return
scale_shape
def
get_scale_shape
(
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
self
,
data_shape
:
Tuple
[
int
,
...],
is_colwise
:
bool
=
False
,
is_padded
:
bool
=
True
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
int
,
...]:
"""Get the shape for scale tensors in block scaling.
...
...
@@ -122,6 +212,7 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
data_shape: The shape of the tensor being quantized
is_colwise: Whether the scaling is column-wise
is_padded: Whether to return padded shape
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
...
...
@@ -135,38 +226,87 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
block_x
,
block_y
=
self
.
_block_dims
alignment_x
,
alignment_y
=
block_alignment
seq_axis
=
len
(
data_shape
)
-
2
if
flatten_axis
<
0
:
flatten_axis
=
len
(
data_shape
)
+
flatten_axis
assert
(
data_shape
[
seq_axis
]
%
block_x
==
0
),
f
"Input data of shape
{
data_shape
}
should be padded by
{
block_x
}
in axes=
{
seq_axis
}
"
0
<
flatten_axis
<
len
(
data_shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
data_shape
}
"
assert
data_shape
[
flatten_axis
-
1
]
%
block_x
==
0
,
(
f
"Data shape
{
data_shape
}
should be divisible by block_x
{
block_x
}
in axis"
f
"
{
flatten_axis
-
1
}
"
)
assert
(
data_shape
[
-
1
]
%
block_y
==
0
),
f
"
Input data of
shape
{
data_shape
}
should be
padded b
y
{
block_y
}
in axis -1"
),
f
"
Data
shape
{
data_shape
}
should be
divisible by block_
y
{
block_y
}
in axis -1"
# NOTE: this overpads if dim > 2 and dims before seq_axis are greater than 1
n_block_seq
=
data_shape
[
seq_axis
]
//
block_x
n_block_y
=
data_shape
[
-
1
]
//
block_y
flattened_first_dim
=
reduce
(
operator
.
mul
,
data_shape
[:
flatten_axis
],
1
)
flattened_last_dim
=
reduce
(
operator
.
mul
,
data_shape
[
flatten_axis
:],
1
)
n_flat_first_dim
=
reduce
(
operator
.
mul
,
data_shape
[:
seq_axis
],
1
)
*
n_block_seq
assert
flattened_first_dim
%
block_x
==
0
,
(
f
"Flattened first dim - mutiplication of axes=
{
tuple
(
range
(
0
,
flatten_axis
))
}
of shape"
f
"
{
data_shape
}
- should be divisible by block_x
{
block_x
}
"
)
assert
flattened_last_dim
%
block_y
==
0
,
(
"Flattened last dim - mutiplication of"
f
" axes=
{
tuple
(
range
(
flatten_axis
,
len
(
data_shape
)))
}
of shape
{
data_shape
}
- should be"
f
" divisible by block_y
{
block_y
}
"
)
# Padding
n_flat_first_dim
=
((
n_flat_first_dim
+
alignment_x
-
1
)
//
alignment_x
)
*
alignment_x
n_block_y
=
((
n_block_y
+
alignment_y
-
1
)
//
alignment_y
)
*
alignment_y
n_block_x
=
int
(
flattened_first_dim
/
block_x
)
n_block_y
=
int
(
flattened_last_dim
/
block_y
)
out_shape
=
()
for
i
in
range
(
seq_axis
):
d
=
data_shape
[
i
]
out_shape
+=
(
d
,)
assert
n_flat_first_dim
%
d
==
0
n_flat_first_dim
//=
d
# padding
n_block_x
=
int
(((
n_block_x
+
alignment_x
-
1
)
//
alignment_x
)
*
alignment_x
)
n_block_y
=
int
(((
n_block_y
+
alignment_y
-
1
)
//
alignment_y
)
*
alignment_y
)
out_shape
+=
(
n_flat_first_dim
,
n_block_y
)
first_dim_scale_shape
=
self
.
_apply_scale_shape_correction
(
data_shape
[:
flatten_axis
],
n_block_x
,
block_x
)
last_dim_scale_shape
=
self
.
_apply_scale_shape_correction
(
data_shape
[
flatten_axis
:],
n_block_y
,
block_y
)
return
out_shape
return
(
*
first_dim_scale_shape
,
*
last_dim_scale_shape
)
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
)
->
QuantizeShardyRules
:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
# (Phuong: Map the NVTEScalingMode value to the ScalingMode
Returns:
The Shardy rules for the scaling mode
"""
input_spec
=
[
f
"x
{
i
}
"
for
i
in
range
(
input_rank
)]
# We have to use two different factors in the two CompoundFactors because of Shardy
# verifier requirements, even though they are the same.
rowwise_var
=
unique_var
colwise_var
=
f
"
{
unique_var
}
_"
input_spec
[
flatten_axis
-
1
]
=
CompoundFactor
(
colwise_var
,
"block_size_colwise"
)
input_spec
[
-
1
]
=
CompoundFactor
(
rowwise_var
,
"block_size_rowwise"
)
# The rowwise and colwise scale tensors should be sharded the same way as the input.
# However, we need to adjust the dimensions where the block scaling factor applies.
rowwise
=
input_spec
.
copy
()
rowwise
[
-
1
]
=
rowwise_var
colwise
=
input_spec
.
copy
()
colwise
[
flatten_axis
-
1
]
=
colwise_var
# This implementation needs to be updated for different block dims.
assert
self
.
_block_dims
==
(
1
,
32
)
return
QuantizeShardyRules
(
tuple
(
input_spec
),
tuple
(
rowwise
),
tuple
(
colwise
),
{
"block_size_rowwise"
:
32
,
"block_size_colwise"
:
32
},
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -175,16 +315,14 @@ class ScalingMode(Enum):
"""Enumeration of tensor scaling modes with their corresponding metadata implementations.
This class defines the available scaling modes for tensor quantization:
- NVTE_DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- NVTE_MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NVTE_INVALID_SCALING: Invalid scaling mode
- NVTE_NO_SCALING: No scaling applied
- DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
- MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
- NO_SCALING: No scaling applied
"""
NVTE_DELAYED_TENSOR_SCALING
=
0
NVTE_MXFP8_1D_SCALING
=
1
NVTE_INVALID_SCALING
=
2
NVTE_NO_SCALING
=
3
NO_SCALING
=
JAXX_Scaling_Mode
.
NO_SCALING
DELAYED_TENSOR_SCALING
=
JAXX_Scaling_Mode
.
DELAYED_TENSOR_SCALING
MXFP8_1D_SCALING
=
JAXX_Scaling_Mode
.
MXFP8_1D_SCALING
def
_get_impl
(
self
)
->
ScalingModeMetadataImpl
:
"""Get the implementation for this scaling mode.
...
...
@@ -208,34 +346,54 @@ class ScalingMode(Enum):
"""
return
self
.
_get_impl
().
get_scale_dtype
()
def
get_scale_shape_2x
(
self
,
data_shape
,
is_padded
=
True
)
->
Tuple
[
Tuple
[
int
]]:
def
get_scale_shape_2x
(
self
,
data_shape
,
is_padded
=
True
,
flatten_axis
=-
1
)
->
Tuple
[
Tuple
[
int
]]:
"""Get shapes for both row-wise and column-wise scaling.
Args:
data_shape: Shape of the data tensor
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
Tuple of (rowwise_scale_shape, colwise_scale_shape)
"""
rowwise_scale_shape
=
self
.
get_scale_shape
(
data_shape
,
is_colwise
=
False
,
is_padded
=
is_padded
data_shape
,
is_colwise
=
False
,
is_padded
=
is_padded
,
flatten_axis
=
flatten_axis
)
colwise_scale_shape
=
self
.
get_scale_shape
(
data_shape
,
is_colwise
=
True
,
is_padded
=
is_padded
,
flatten_axis
=
flatten_axis
)
colwise_scale_shape
=
self
.
get_scale_shape
(
data_shape
,
is_colwise
=
True
,
is_padded
=
is_padded
)
return
(
rowwise_scale_shape
,
colwise_scale_shape
)
def
get_scale_shape
(
self
,
data_shape
,
is_colwise
,
is_padded
=
True
)
->
Tuple
[
int
]:
def
get_scale_shape
(
self
,
data_shape
,
is_colwise
,
is_padded
=
True
,
flatten_axis
=-
1
)
->
Tuple
[
int
]:
"""Get the shape for scale tensors in this mode.
Args:
data_shape: Shape of the data tensor
is_colwise: Whether to use column-wise scaling
is_padded: Whether to use padded shapes
flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
Returns:
The shape for scale tensors
"""
return
self
.
_get_impl
().
get_scale_shape
(
data_shape
,
is_colwise
,
is_padded
)
return
self
.
_get_impl
().
get_scale_shape
(
data_shape
,
is_colwise
,
is_padded
,
flatten_axis
)
def
get_shardy_sharding_rules
(
self
,
input_rank
,
unique_var
,
flatten_axis
=-
1
)
->
Tuple
[
Tuple
[
str
]]:
"""Sharding rules for the input and (row, col)wise scale tensors.
Args:
input_rank: The rank of the input tensor (for which we produce the scale tensor)
unique_var: An otherwise unused Shardy variable name prefix
Returns:
The Shardy rules for the scaling mode
"""
return
self
.
_get_impl
().
get_shardy_sharding_rules
(
input_rank
,
unique_var
,
flatten_axis
)
def
__eq__
(
self
,
other
):
"""Compare this scaling mode with another.
...
...
@@ -273,8 +431,8 @@ class ScalingMode(Enum):
SCALING_MODES_TO_IMPL
:
Dict
[
ScalingMode
,
ScalingModeMetadataImpl
]
=
{
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
DelayedScalingModeMetadataImpl
(),
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
BlockScalingModeMetadataImpl
(
block_dims
=
(
1
,
32
)),
ScalingMode
.
DELAYED_TENSOR_SCALING
:
DelayedScalingModeMetadataImpl
(),
ScalingMode
.
MXFP8_1D_SCALING
:
BlockScalingModeMetadataImpl
(
block_dims
=
(
1
,
32
)),
# WAR
ScalingMode
.
NVTE_
NO_SCALING
:
DelayedScalingModeMetadataImpl
(),
ScalingMode
.
NO_SCALING
:
DelayedScalingModeMetadataImpl
(),
}
transformer_engine/jax/quantize/tensor.py
View file @
ab3e5a92
...
...
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
import
jax.numpy
as
jnp
from
jax.tree_util
import
register_pytree_node_class
from
transformer_engine_jax
import
Quantize
Axis
from
transformer_engine_jax
import
Quantize
Layout
from
.scaling_modes
import
ScalingMode
from
.dequantizer
import
Dequantizer
...
...
@@ -84,6 +84,17 @@ class ScaledTensor(ABC):
ValueError: If called on a tensor that doesn't support column-wise access
"""
@
abstractmethod
def
apply_sharding_constraint_by_logical_axes
(
self
,
logical_axis_names
:
Tuple
[
str
,
...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
@
register_pytree_node_class
@
dataclass
...
...
@@ -100,7 +111,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype: The data type for dequantized values
_dq_func: The dequantization function
is_colwise: Whether the tensor uses column-wise quantization
layout: The layout specification for the tensor
data_layout: The data_layout specification for the tensor
flatten_axis: The quantization axis for the tensor
"""
data
:
jnp
.
ndarray
...
...
@@ -109,7 +121,8 @@ class ScaledTensor1x(ScaledTensor):
dq_dtype
:
jnp
.
dtype
_dq_func
:
Callable
is_colwise
:
bool
layout
:
str
data_layout
:
str
flatten_axis
:
int
=
-
1
def
__post_init__
(
self
):
"""Validates and adjusts the scale_inv shape after initialization.
...
...
@@ -117,11 +130,22 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary.
"""
flatten_axis
=
(
len
(
self
.
data
.
shape
)
+
self
.
flatten_axis
if
self
.
flatten_axis
<
0
else
self
.
flatten_axis
)
assert
(
0
<
flatten_axis
<
len
(
self
.
data
.
shape
)
),
f
"flatten_axis
{
flatten_axis
}
is out of bounds for shape
{
self
.
data
.
shape
}
"
if
self
.
data_layout
==
"T"
:
flatten_axis
=
self
.
data
.
ndim
-
flatten_axis
self
.
flatten_axis
=
flatten_axis
expected_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
True
,
flatten_axis
=
flatten_axis
)
expected_unpadded_scale_shape
=
self
.
scaling_mode
.
get_scale_shape
(
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
self
.
data
.
shape
,
self
.
is_colwise
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
if
self
.
scale_inv
.
shape
!=
expected_scale_shape
:
assert
self
.
scale_inv
.
shape
==
expected_unpadded_scale_shape
,
(
...
...
@@ -144,7 +168,14 @@ class ScaledTensor1x(ScaledTensor):
A tuple containing (children, aux_data) for tree operations
"""
children
=
(
self
.
data
,
self
.
scale_inv
)
aux_data
=
(
self
.
scaling_mode
,
self
.
dq_dtype
,
self
.
_dq_func
,
self
.
is_colwise
,
self
.
layout
)
aux_data
=
(
self
.
scaling_mode
,
self
.
dq_dtype
,
self
.
_dq_func
,
self
.
is_colwise
,
self
.
data_layout
,
self
.
flatten_axis
,
)
return
(
children
,
aux_data
)
def
dequantize
(
self
):
...
...
@@ -183,6 +214,45 @@ class ScaledTensor1x(ScaledTensor):
raise
ValueError
(
"Calling get_colwise_tensor() from a rowwise ScaledTensor1x!"
)
def
apply_sharding_constraint_by_logical_axes
(
self
,
logical_axis_names
:
Tuple
[
str
,
...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if
not
logical_axis_names
:
return
self
# axis_names were given for N layout, so needs to be transpose for T layout
if
self
.
data_layout
==
"T"
:
assert
self
.
flatten_axis
>
0
flatten_axis
=
-
self
.
flatten_axis
axis_names
=
(
*
logical_axis_names
[
flatten_axis
:],
*
logical_axis_names
[:
flatten_axis
])
else
:
axis_names
=
logical_axis_names
data
=
with_sharding_constraint_by_logical_axes
(
self
.
data
,
axis_names
)
if
self
.
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
# TODO(Phuong): Handle padding !?
scale_inv
=
with_sharding_constraint_by_logical_axes
(
self
.
scale_inv
,
axis_names
)
else
:
scale_inv
=
self
.
scale_inv
return
ScaledTensor1x
(
data
=
data
,
scale_inv
=
scale_inv
,
scaling_mode
=
self
.
scaling_mode
,
dq_dtype
=
self
.
dq_dtype
,
_dq_func
=
self
.
_dq_func
,
is_colwise
=
self
.
is_colwise
,
data_layout
=
self
.
data_layout
,
flatten_axis
=
self
.
flatten_axis
,
)
@
register_pytree_node_class
@
dataclass
...
...
@@ -233,6 +303,27 @@ class ScaledTensor2x(ScaledTensor):
"""
return
self
.
colwise_tensor
def
apply_sharding_constraint_by_logical_axes
(
self
,
logical_axis_names
:
Tuple
[
str
,
...]):
"""Applies sharding constraints to a tensor based on logical axis names.
Args:
logical_axis_names: Tuple of logical axis names for sharding
Returns:
The tensor with applied sharding constraints
"""
if
not
logical_axis_names
:
return
self
rowwise_tensor
=
self
.
rowwise_tensor
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
colwise_tensor
=
self
.
colwise_tensor
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
@
dataclass
class
ScaledTensorFactory
:
...
...
@@ -244,7 +335,13 @@ class ScaledTensorFactory:
@
staticmethod
def
create_1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
is_colwise
=
False
,
layout
=
"N"
data
,
scale_inv
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
is_colwise
=
False
,
data_layout
=
"N"
,
flatten_axis
=-
1
,
):
"""Creates a single-scale quantized tensor.
...
...
@@ -254,13 +351,16 @@ class ScaledTensorFactory:
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
is_colwise: Whether to use column-wise quantization (default: False)
layout: The layout specification (default: "N")
data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor1x instance
"""
dq_func
=
Dequantizer
.
funcs
.
get
(
scaling_mode
)
return
ScaledTensor1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dq_func
,
is_colwise
,
layout
)
return
ScaledTensor1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
dq_func
,
is_colwise
,
data_layout
,
flatten_axis
)
@
staticmethod
def
create_2x
(
...
...
@@ -270,7 +370,8 @@ class ScaledTensorFactory:
colwise_scale_inv
,
scaling_mode
,
dq_dtype
=
jnp
.
bfloat16
,
layout
=
"NN"
,
data_layout
=
"NN"
,
flatten_axis
=-
1
,
):
"""Creates a double-scale quantized tensor.
...
...
@@ -281,7 +382,8 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor
Returns:
A ScaledTensor2x instance
...
...
@@ -294,7 +396,8 @@ class ScaledTensorFactory:
dq_dtype
,
dq_func
,
is_colwise
=
False
,
layout
=
layout
[
0
],
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
)
colwise_tensor
=
ScaledTensor1x
(
colwise_data
,
...
...
@@ -303,7 +406,8 @@ class ScaledTensorFactory:
dq_dtype
,
dq_func
,
is_colwise
=
True
,
layout
=
layout
[
1
],
data_layout
=
data_layout
[
1
],
flatten_axis
=
flatten_axis
,
)
return
ScaledTensor2x
(
rowwise_tensor
,
colwise_tensor
)
...
...
@@ -315,8 +419,9 @@ class ScaledTensorFactory:
colwise_scale_inv
:
jnp
.
ndarray
,
scaling_mode
:
ScalingMode
,
dq_dtype
:
jnp
.
dtype
=
jnp
.
bfloat16
,
layout
:
str
=
"NN"
,
q_axis
:
QuantizeAxis
=
QuantizeAxis
.
ROWWISE
,
data_layout
:
str
=
"NN"
,
q_layout
:
QuantizeLayout
=
QuantizeLayout
.
ROWWISE
,
flatten_axis
:
int
=
-
1
,
):
"""Creates a scaled tensor based on the quantization axis.
...
...
@@ -327,13 +432,13 @@ class ScaledTensorFactory:
colwise_scale_inv: The column-wise inverse scaling factors
scaling_mode: The scaling mode for quantization
dq_dtype: The data type for dequantized values (default: bfloat16)
layout: The layout specification (default: "NN")
q_
axis
: The quantization axis (default: ROWWISE)
data_
layout: The
data_
layout specification (default: "NN")
q_
layout
: The quantization axis (default: ROWWISE)
Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_
axis
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_
layout
"""
if
q_
axis
==
Quantize
Axis
.
ROWWISE_COLWISE
:
if
q_
layout
==
Quantize
Layout
.
ROWWISE_COLWISE
:
return
ScaledTensorFactory
.
create_2x
(
data
,
scale_inv
,
...
...
@@ -341,12 +446,19 @@ class ScaledTensorFactory:
colwise_scale_inv
,
scaling_mode
,
dq_dtype
,
layout
=
layout
,
data_layout
=
data_layout
,
flatten_axis
=
flatten_axis
,
)
is_colwise
=
q_
axis
==
Quantize
Axis
.
COLWISE
is_colwise
=
q_
layout
==
Quantize
Layout
.
COLWISE
return
ScaledTensorFactory
.
create_1x
(
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
is_colwise
=
is_colwise
,
layout
=
layout
[
0
]
data
,
scale_inv
,
scaling_mode
,
dq_dtype
,
is_colwise
=
is_colwise
,
data_layout
=
data_layout
[
0
],
flatten_axis
=
flatten_axis
,
)
...
...
@@ -360,24 +472,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns:
The tensor with applied sharding constraints
"""
if
isinstance
(
x
,
ScaledTensor1x
):
return
ScaledTensor1x
(
data
=
with_sharding_constraint_by_logical_axes
(
x
.
data
,
logical_axis_names
),
scale_inv
=
x
.
scale_inv
,
scaling_mode
=
x
.
scaling_mode
,
dq_dtype
=
x
.
dq_dtype
,
_dq_func
=
x
.
_dq_func
,
is_colwise
=
x
.
is_colwise
,
layout
=
x
.
layout
,
)
if
isinstance
(
x
,
ScaledTensor2x
):
return
ScaledTensor2x
(
rowwise_tensor
=
with_sharding_constraint_by_logical_axes
(
x
.
rowwise_tensor
,
logical_axis_names
),
colwise_tensor
=
with_sharding_constraint_by_logical_axes
(
x
.
colwise_tensor
,
logical_axis_names
),
)
if
isinstance
(
x
,
ScaledTensor
):
return
x
.
apply_sharding_constraint_by_logical_axes
(
logical_axis_names
)
return
original_with_sharding_constraint_by_logical_axes
(
x
,
logical_axis_names
)
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment