Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ab3e5a92
Commit
ab3e5a92
authored
May 09, 2025
by
yuguo
Browse files
Merge commit '
04c730c0
' of...
Merge commit '
04c730c0
' of
https://github.com/NVIDIA/TransformerEngine
parents
a8d19fd9
04c730c0
Changes
174
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2561 additions
and
566 deletions
+2561
-566
tests/cpp/operator/CMakeLists.txt
tests/cpp/operator/CMakeLists.txt
+1
-0
tests/cpp/operator/test_cast_float8blockwise.cu
tests/cpp/operator/test_cast_float8blockwise.cu
+655
-0
tests/cpp/operator/test_normalization.cu
tests/cpp/operator/test_normalization.cu
+30
-159
tests/cpp/operator/test_normalization.h
tests/cpp/operator/test_normalization.h
+188
-0
tests/cpp/operator/test_normalization_mxfp8.cu
tests/cpp/operator/test_normalization_mxfp8.cu
+24
-77
tests/cpp/test_common.cu
tests/cpp/test_common.cu
+97
-53
tests/cpp/test_common.h
tests/cpp/test_common.h
+11
-12
tests/jax/pytest.ini
tests/jax/pytest.ini
+2
-0
tests/jax/test_custom_call_compute.py
tests/jax/test_custom_call_compute.py
+152
-122
tests/jax/test_distributed_fused_attn.py
tests/jax/test_distributed_fused_attn.py
+216
-69
tests/jax/test_distributed_layernorm.py
tests/jax/test_distributed_layernorm.py
+7
-1
tests/jax/test_distributed_layernorm_mlp.py
tests/jax/test_distributed_layernorm_mlp.py
+112
-46
tests/jax/test_distributed_softmax.py
tests/jax/test_distributed_softmax.py
+82
-14
tests/jax/test_layer.py
tests/jax/test_layer.py
+46
-5
tests/jax/test_softmax.py
tests/jax/test_softmax.py
+91
-5
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
+273
-1
tests/pytorch/distributed/run_numerics.py
tests/pytorch/distributed/run_numerics.py
+23
-1
tests/pytorch/distributed/test_numerics.py
tests/pytorch/distributed/test_numerics.py
+6
-1
tests/pytorch/references/blockwise_fp8_gemm_reference.py
tests/pytorch/references/blockwise_fp8_gemm_reference.py
+242
-0
tests/pytorch/references/blockwise_quantizer_reference.py
tests/pytorch/references/blockwise_quantizer_reference.py
+303
-0
No files found.
tests/cpp/operator/CMakeLists.txt
View file @
ab3e5a92
...
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
...
@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
test_cast_mxfp8_gated_swiglu.cu
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8.cu
# test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
test_dequantize_mxfp8.cu
test_transpose.cu
test_transpose.cu
test_cast_transpose.cu
test_cast_transpose.cu
...
...
tests/cpp/operator/test_cast_float8blockwise.cu
0 → 100644
View file @
ab3e5a92
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>

#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"
using
namespace
transformer_engine
;
using
namespace
test
;
namespace
{
template
<
typename
InputType
,
typename
OutputType
>
void
scales_from_amax
(
float
amax
,
const
QuantizationOptions
&
opts
,
float
*
qscale_out
,
float
*
qscale_inv_out
)
{
float
input_type_max_val
=
Quantized_Limits
<
InputType
>::
max
();
float
quant_type_max_val
=
Quantized_Limits
<
OutputType
>::
max
();
float
eps
=
opts
.
amax_epsilon
;
amax
=
std
::
max
(
amax
,
eps
);
float
qscale
=
quant_type_max_val
/
amax
;
if
(
std
::
isinf
(
qscale
))
{
qscale
=
input_type_max_val
;
}
if
(
std
::
isnan
(
qscale
)
||
amax
==
0
)
{
qscale
=
1.0
;
}
if
(
opts
.
force_pow_2_scales
&&
qscale
!=
0.0
)
{
uint32_t
scale_bits
=
*
reinterpret_cast
<
uint32_t
*>
(
&
qscale
);
// Scale must be positive, shift it
uint8_t
exp
=
scale_bits
>>
23
;
ASSERT_FALSE
(
exp
==
0
)
<<
"Subnormals in this path is a logic error."
;
qscale
=
ldexpf
(
1.0
f
,
static_cast
<
int32_t
>
(
exp
)
-
127
);
}
float
qscale_inv
=
1.0
/
qscale
;
*
qscale_out
=
qscale
;
*
qscale_inv_out
=
qscale_inv
;
}
template
<
typename
InputType
,
typename
OutputType
>
void
ref_quantize
(
const
ProcessingMethod
processing_method
,
const
InputType
*
input
,
const
std
::
pair
<
size_t
,
size_t
>&
input_hw
,
OutputType
*
output
,
float
*
scale_inv
,
OutputType
*
output_t
,
float
*
scale_inv_t
,
const
QuantizationOptions
&
opts
)
{
constexpr
size_t
kBlockLenX
=
kBlockLen
;
constexpr
size_t
kBlockLenY
=
kBlockLen
;
auto
quantize_element
=
[](
InputType
element
,
float
qscale
)
->
OutputType
{
// Scale in FP32 and cast result to nearest FP8.
return
static_cast
<
OutputType
>
(
float
(
element
)
*
qscale
);
};
size_t
height
=
input_hw
.
first
;
size_t
width
=
input_hw
.
second
;
size_t
blocks_x
=
(
width
+
kBlockLenX
-
1
)
/
kBlockLenX
;
size_t
blocks_y
=
(
height
+
kBlockLenY
-
1
)
/
kBlockLenY
;
// Find the absolute maximum value in the block
for
(
size_t
block_x
=
0
;
block_x
<
blocks_x
;
++
block_x
)
{
for
(
size_t
block_y
=
0
;
block_y
<
blocks_y
;
++
block_y
)
{
float
amax
=
0.0
f
;
// Calculate amax for a tile.
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
kBlockLenY
;
++
j
)
{
size_t
x_pos
=
i
+
block_x
*
kBlockLenX
;
size_t
y_pos
=
j
+
block_y
*
kBlockLenY
;
if
(
y_pos
>=
height
||
x_pos
>=
width
)
{
continue
;
}
float
val
=
static_cast
<
float
>
(
input
[
y_pos
*
width
+
x_pos
]);
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
}
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float
qscale
,
qscale_inv
;
scales_from_amax
<
InputType
,
OutputType
>
(
amax
,
opts
,
&
qscale
,
&
qscale_inv
);
// NOTE: This reference function outputs contigous scale tensors.
// It calculates a naive scale data format. Strides are handled
// in comparison.
if
(
scale_inv
!=
nullptr
)
{
scale_inv
[
block_y
*
blocks_x
+
block_x
]
=
qscale_inv
;
}
if
(
scale_inv_t
!=
nullptr
)
{
scale_inv_t
[
block_x
*
blocks_y
+
block_y
]
=
qscale_inv
;
}
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
kBlockLenY
;
++
j
)
{
size_t
x_pos
=
i
+
block_x
*
kBlockLenX
;
size_t
y_pos
=
j
+
block_y
*
kBlockLenY
;
if
(
y_pos
>=
height
||
x_pos
>=
width
)
{
continue
;
}
if
(
output
!=
nullptr
)
{
output
[
y_pos
*
width
+
x_pos
]
=
quantize_element
(
input
[
y_pos
*
width
+
x_pos
],
qscale
);
}
if
(
output_t
!=
nullptr
)
{
output_t
[
x_pos
*
height
+
y_pos
]
=
quantize_element
(
input
[
y_pos
*
width
+
x_pos
],
qscale
);
}
}
}
}
}
}
template
<
typename
InputType
,
typename
OutputType
>
void
ref_quantize_onedimensional_blocks
(
const
ProcessingMethod
processing_method
,
const
InputType
*
input
,
const
std
::
pair
<
size_t
,
size_t
>&
input_hw
,
OutputType
*
output
,
float
*
scale_inv
,
OutputType
*
output_t
,
float
*
scale_inv_t
,
const
QuantizationOptions
&
opts
)
{
float
input_type_max_val
=
Quantized_Limits
<
InputType
>::
max
();
float
quant_type_max_val
=
Quantized_Limits
<
OutputType
>::
max
();
constexpr
size_t
kBlockLenX
=
kBlockLen
;
auto
quantize_element
=
[](
InputType
element
,
float
qscale
)
->
OutputType
{
// Scale in FP32 and cast result to nearest FP8.
return
static_cast
<
OutputType
>
(
float
(
element
)
*
qscale
);
};
size_t
height
=
input_hw
.
first
;
size_t
width
=
input_hw
.
second
;
size_t
blocks_x
=
(
width
+
kBlockLenX
-
1
)
/
kBlockLenX
;
size_t
blocks_x_t
=
(
height
+
kBlockLenX
-
1
)
/
kBlockLenX
;
if
(
output
!=
nullptr
&&
scale_inv
!=
nullptr
)
{
// Find the absolute maximum value in the block
for
(
size_t
block_x
=
0
;
block_x
<
blocks_x
;
++
block_x
)
{
for
(
size_t
y
=
0
;
y
<
height
;
++
y
)
{
float
amax
=
0.0
f
;
// Calculate amax for a tile.
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
size_t
x_pos
=
i
+
block_x
*
kBlockLenX
;
if
(
x_pos
>=
width
)
{
continue
;
}
float
val
=
static_cast
<
float
>
(
input
[
y
*
width
+
x_pos
]);
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float
qscale
,
qscale_inv
;
scales_from_amax
<
InputType
,
OutputType
>
(
amax
,
opts
,
&
qscale
,
&
qscale_inv
);
scale_inv
[
y
+
height
*
block_x
]
=
qscale_inv
;
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
size_t
x_pos
=
i
+
block_x
*
kBlockLenX
;
if
(
x_pos
>=
width
)
{
continue
;
}
output
[
y
*
width
+
x_pos
]
=
quantize_element
(
input
[
y
*
width
+
x_pos
],
qscale
);
}
}
}
}
if
(
output_t
!=
nullptr
&&
scale_inv_t
!=
nullptr
)
{
// Find the absolute maximum value in the block
for
(
size_t
block_x_t
=
0
;
block_x_t
<
blocks_x_t
;
++
block_x_t
)
{
for
(
size_t
x
=
0
;
x
<
width
;
++
x
)
{
float
amax
=
0.0
f
;
// Calculate amax for a tile.
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
size_t
y_pos
=
i
+
block_x_t
*
kBlockLenX
;
if
(
y_pos
>=
height
)
{
continue
;
}
float
val
=
static_cast
<
float
>
(
input
[
x
+
y_pos
*
width
]);
amax
=
std
::
max
(
amax
,
std
::
abs
(
val
));
}
// We've calculated amax for a tile. Calculate scale and
// scale_inv and populate outputs.
float
qscale
,
qscale_inv
;
scales_from_amax
<
InputType
,
OutputType
>
(
amax
,
opts
,
&
qscale
,
&
qscale_inv
);
scale_inv_t
[
x
+
width
*
block_x_t
]
=
qscale_inv
;
for
(
size_t
i
=
0
;
i
<
kBlockLenX
;
++
i
)
{
size_t
y_pos
=
i
+
block_x_t
*
kBlockLenX
;
if
(
y_pos
>=
height
)
{
continue
;
}
output_t
[
x
*
height
+
y_pos
]
=
quantize_element
(
input
[
y_pos
*
width
+
x
],
qscale
);
}
}
}
}
}
inline
size_t
scale_align_stride
(
size_t
inner_elements
)
{
return
((
inner_elements
+
4u
-
1u
)
/
4u
)
*
4u
;
};
void
compare_scaling_factors
(
const
std
::
string
&
name
,
const
float
*
test
,
const
float
*
ref
,
const
size_t
row_blocks
,
const
size_t
col_blocks
,
const
size_t
test_stride
,
const
size_t
ref_stride
)
{
for
(
int
i
=
0
;
i
<
row_blocks
;
++
i
)
{
for
(
int
j
=
0
;
j
<
col_blocks
;
++
j
)
{
const
int
test_idx
=
i
*
test_stride
+
j
;
const
int
ref_idx
=
i
*
ref_stride
+
j
;
ASSERT_FALSE
(
test
[
test_idx
]
!=
ref
[
ref_idx
])
<<
"Error in "
<<
name
<<
std
::
endl
<<
"Mismatch: "
<<
test
[
test_idx
]
<<
" vs "
<<
ref
[
ref_idx
]
<<
" at index "
<<
test_idx
<<
","
<<
ref_idx
;
}
}
}
void
compare_scaling_factors_one_dimensional_blocks
(
const
std
::
string
&
name
,
const
float
*
test
,
const
float
*
ref
,
const
size_t
rows
,
const
size_t
col_blocks
)
{
const
size_t
test_stride
=
scale_align_stride
(
rows
);
for
(
int
i
=
0
;
i
<
rows
;
++
i
)
{
for
(
int
j
=
0
;
j
<
col_blocks
;
++
j
)
{
const
int
test_idx
=
i
+
test_stride
*
j
;
const
int
ref_idx
=
i
+
rows
*
j
;
ASSERT_FALSE
(
test
[
test_idx
]
!=
ref
[
ref_idx
])
<<
"Error in "
<<
name
<<
std
::
endl
<<
"Mismatch: "
<<
test
[
test_idx
]
<<
" vs "
<<
ref
[
ref_idx
]
<<
" at index "
<<
test_idx
<<
","
<<
ref_idx
;
}
}
}
template
<
typename
InputType
,
typename
OutputType
>
void
runTestCase
(
const
ProcessingMethod
processing_method
,
const
std
::
vector
<
size_t
>&
shape
,
const
bool
rowwise
,
const
bool
colwise
,
InputsFillCase
fill_case
,
const
QuantizationOptions
&
opts
)
{
using
namespace
test
;
using
EncodingType
=
fp32
;
DType
itype
=
TypeInfo
<
InputType
>::
dtype
;
DType
otype
=
TypeInfo
<
OutputType
>::
dtype
;
const
size_t
rows
=
first_dimension
(
shape
);
const
size_t
cols
=
last_dimension
(
shape
);
size_t
blocks_x
=
(
cols
+
kBlockLen
-
1
)
/
kBlockLen
;
size_t
blocks_y
=
(
rows
+
kBlockLen
-
1
)
/
kBlockLen
;
Tensor
input
(
"input"
,
shape
,
itype
);
Tensor
grad
(
"grad"
,
shape
,
itype
);
Tensor
output_c
(
"output_c"
,
shape
,
otype
,
rowwise
,
colwise
,
opts
.
block_scaling_dim
==
2
?
NVTE_BLOCK_SCALING_2D
:
NVTE_BLOCK_SCALING_1D
);
Tensor
output_dbias
(
"output_dbias"
,
{
cols
},
itype
);
std
::
unique_ptr
<
OutputType
[]
>
ref_output
=
std
::
make_unique
<
OutputType
[]
>
(
rows
*
cols
);
std
::
unique_ptr
<
OutputType
[]
>
ref_output_t
=
std
::
make_unique
<
OutputType
[]
>
(
rows
*
cols
);
std
::
unique_ptr
<
float
[]
>
ref_scale_inv
=
std
::
make_unique
<
float
[]
>
(
blocks_y
*
blocks_x
);
std
::
unique_ptr
<
float
[]
>
ref_scale_inv_t
=
std
::
make_unique
<
float
[]
>
(
blocks_y
*
blocks_x
);
if
(
!
rowwise
)
{
ref_output
=
nullptr
;
ref_scale_inv
=
nullptr
;
}
if
(
!
colwise
)
{
ref_output_t
=
nullptr
;
ref_scale_inv_t
=
nullptr
;
}
fillCase
<
EncodingType
>
(
&
input
,
fill_case
);
fillUniform
(
&
grad
);
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_force_pow_2_scales
(
opts
.
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
opts
.
amax_epsilon
);
Tensor
workspace
;
switch
(
processing_method
)
{
case
ProcessingMethod
::
CAST_ONLY
:
{
nvte_quantize_v2
(
input
.
data
(),
output_c
.
data
(),
quant_config
,
nullptr
);
break
;
}
}
cudaDeviceSynchronize
();
auto
err
=
cudaGetLastError
();
ASSERT_EQ
(
err
,
cudaSuccess
)
<<
cudaGetErrorString
(
err
);
ref_quantize
<
InputType
,
OutputType
>
(
processing_method
,
input
.
rowwise_cpu_dptr
<
InputType
>
(),
{
rows
,
cols
},
ref_output
.
get
(),
ref_scale_inv
.
get
(),
ref_output_t
.
get
(),
ref_scale_inv_t
.
get
(),
opts
);
float
atol
=
0.0
;
float
rtol
=
0.0
;
if
(
rowwise
)
{
compareResults
(
"output_c"
,
output_c
,
ref_output
.
get
(),
true
,
atol
,
rtol
);
compare_scaling_factors
(
"scale_inv"
,
output_c
.
rowwise_cpu_scale_inv_ptr
<
float
>
(),
ref_scale_inv
.
get
(),
blocks_y
,
blocks_x
,
scale_align_stride
(
blocks_x
),
blocks_x
);
}
if
(
colwise
)
{
compareResults
(
"output_c_t"
,
output_c
,
ref_output_t
.
get
(),
false
,
atol
,
rtol
);
compare_scaling_factors
(
"scale_inv_t"
,
output_c
.
columnwise_cpu_scale_inv_ptr
<
float
>
(),
ref_scale_inv_t
.
get
(),
blocks_x
,
blocks_y
,
scale_align_stride
(
blocks_y
),
blocks_y
);
}
}
template
<
typename
InputType
,
typename
OutputType
>
void
runTestCaseOneDimensionalBlocks
(
const
ProcessingMethod
processing_method
,
const
std
::
vector
<
size_t
>&
shape
,
const
bool
rowwise
,
const
bool
colwise
,
InputsFillCase
fill_case
,
const
QuantizationOptions
&
opts
)
{
using
namespace
test
;
using
EncodingType
=
fp32
;
DType
itype
=
TypeInfo
<
InputType
>::
dtype
;
DType
otype
=
TypeInfo
<
OutputType
>::
dtype
;
const
size_t
rows
=
first_dimension
(
shape
);
const
size_t
cols
=
last_dimension
(
shape
);
size_t
blocks_x
=
(
cols
+
kBlockLen
-
1
)
/
kBlockLen
;
size_t
blocks_x_t
=
(
rows
+
kBlockLen
-
1
)
/
kBlockLen
;
Tensor
input
(
"input"
,
shape
,
itype
);
Tensor
grad
(
"grad"
,
shape
,
itype
);
Tensor
output_c
(
"output_c"
,
shape
,
otype
,
rowwise
,
colwise
,
opts
.
block_scaling_dim
==
2
?
NVTE_BLOCK_SCALING_2D
:
NVTE_BLOCK_SCALING_1D
);
Tensor
output_dbias
(
"output_dbias"
,
{
cols
},
itype
);
std
::
unique_ptr
<
OutputType
[]
>
ref_output
=
std
::
make_unique
<
OutputType
[]
>
(
rows
*
cols
);
std
::
unique_ptr
<
OutputType
[]
>
ref_output_t
=
std
::
make_unique
<
OutputType
[]
>
(
rows
*
cols
);
std
::
unique_ptr
<
float
[]
>
ref_scale_inv
=
std
::
make_unique
<
float
[]
>
(
rows
*
blocks_x
);
std
::
unique_ptr
<
float
[]
>
ref_scale_inv_t
=
std
::
make_unique
<
float
[]
>
(
cols
*
blocks_x_t
);
if
(
!
rowwise
)
{
ref_output
=
nullptr
;
ref_scale_inv
=
nullptr
;
}
if
(
!
colwise
)
{
ref_output_t
=
nullptr
;
ref_scale_inv_t
=
nullptr
;
}
fillCase
<
EncodingType
>
(
&
input
,
fill_case
);
fillUniform
(
&
grad
);
Tensor
workspace
;
QuantizationConfigWrapper
quant_config
;
quant_config
.
set_force_pow_2_scales
(
opts
.
force_pow_2_scales
);
quant_config
.
set_amax_epsilon
(
opts
.
amax_epsilon
);
switch
(
processing_method
)
{
case
ProcessingMethod
::
CAST_ONLY
:
{
nvte_quantize_v2
(
input
.
data
(),
output_c
.
data
(),
quant_config
,
nullptr
);
break
;
}
}
cudaDeviceSynchronize
();
auto
err
=
cudaGetLastError
();
ASSERT_EQ
(
err
,
cudaSuccess
)
<<
cudaGetErrorString
(
err
);
ref_quantize_onedimensional_blocks
<
InputType
,
OutputType
>
(
processing_method
,
input
.
rowwise_cpu_dptr
<
InputType
>
(),
{
rows
,
cols
},
ref_output
.
get
(),
ref_scale_inv
.
get
(),
ref_output_t
.
get
(),
ref_scale_inv_t
.
get
(),
opts
);
float
atol
=
0.0
;
float
rtol
=
0.0
;
if
(
rowwise
)
{
compareResults
(
"output_c"
,
output_c
,
ref_output
.
get
(),
true
,
atol
,
rtol
);
compare_scaling_factors_one_dimensional_blocks
(
"scale_inv"
,
output_c
.
rowwise_cpu_scale_inv_ptr
<
float
>
(),
ref_scale_inv
.
get
(),
rows
,
blocks_x
);
}
if
(
colwise
)
{
compareResults
(
"output_c_t"
,
output_c
,
ref_output_t
.
get
(),
false
,
atol
,
rtol
);
compare_scaling_factors_one_dimensional_blocks
(
"scale_inv_t"
,
output_c
.
columnwise_cpu_scale_inv_ptr
<
float
>
(),
ref_scale_inv_t
.
get
(),
cols
,
blocks_x_t
);
}
}
std
::
vector
<
std
::
vector
<
size_t
>>
matrix_sizes
=
{
{
1
,
16
},
{
65
,
96
},
{
256
,
256
},
{
993
,
512
},
{
256
,
65536
},
{
4096
,
1632
},
{
1024
,
1
},
{
16
,
512
},
{
1024
},
{
8
,
32
,
1024
},
{
16
,
8
,
4
,
512
},
};
std
::
vector
<
InputsFillCase
>
input_scenarios
=
{
InputsFillCase
::
uniform
,
};
std
::
vector
<
ProcessingMethod
>
processing_methods
=
{
ProcessingMethod
::
CAST_ONLY
,
// ProcessingMethod::CAST_DBIAS,
// ProcessingMethod::CAST_DBIAS_DACT,
// ProcessingMethod::CAST_DACT,
// ProcessingMethod::CAST_ACT,
};
// Only GeLU activation tests are supported
std
::
vector
<
ActivationType
>
Activation_types
=
{
ActivationType
::
Identity
,
// ActivationType::GeLU,
// ActivationType::SiLU,
// ActivationType::ReLU,
// ActivationType::QGeLU,
// ActivationType::SReLU,
};
std
::
vector
<
float
>
amax_epsilons
=
{
0.0
f
,
1.0
f
,
// Make large to be observable.
};
}
// namespace
class
FusedCastFloat8BlockwiseTestSuite
:
public
::
testing
::
TestWithParam
<
std
::
tuple
<
ProcessingMethod
,
ActivationType
,
std
::
vector
<
size_t
>
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
InputsFillCase
,
bool
,
float
,
bool
>>
{};
class
FusedCastFloat8VectorwiseTestSuite
:
public
::
testing
::
TestWithParam
<
std
::
tuple
<
ProcessingMethod
,
ActivationType
,
std
::
vector
<
size_t
>
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
InputsFillCase
,
bool
,
float
,
bool
>>
{};
#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { \
constexpr auto OP = &identity; \
{ \
__VA_ARGS__ \
} \
} break; \
}
#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
switch (OP_FUNC_TYPE) { \
case ActivationType::Identity: { \
constexpr auto OP = &identity; \
{ \
__VA_ARGS__ \
} \
} break; \
}
TEST_P
(
FusedCastFloat8BlockwiseTestSuite
,
TestFusedCastFloat8Blockwise
)
{
if
(
getDeviceComputeCapability
()
<
hopperComputeCapability
)
{
GTEST_SKIP
();
}
using
namespace
transformer_engine
;
using
namespace
test
;
const
ProcessingMethod
processing_method
=
std
::
get
<
0
>
(
GetParam
());
const
ActivationType
Act_type
=
std
::
get
<
1
>
(
GetParam
());
const
auto
matrix_size
=
std
::
get
<
2
>
(
GetParam
());
const
DType
input_type
=
std
::
get
<
3
>
(
GetParam
());
const
DType
output_type
=
std
::
get
<
4
>
(
GetParam
());
const
InputsFillCase
fill_case
=
std
::
get
<
5
>
(
GetParam
());
const
bool
colwise
=
std
::
get
<
6
>
(
GetParam
());
const
bool
rowwise
=
true
;
const
float
eps
=
std
::
get
<
7
>
(
GetParam
());
const
bool
force_pow_2
=
std
::
get
<
8
>
(
GetParam
());
QuantizationOptions
q_opts
;
q_opts
.
force_pow_2_scales
=
force_pow_2
;
q_opts
.
amax_epsilon
=
eps
;
q_opts
.
block_scaling_dim
=
2u
;
if
(
colwise
&&
matrix_size
.
size
()
<
2
)
{
// test_common Tensor initialization code does not
// handle this case.
GTEST_SKIP
();
}
// Skips non Act tests if the Activation type is not an identity
if
(
// (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
(
processing_method
==
ProcessingMethod
::
CAST_ONLY
)
&&
Act_type
!=
ActivationType
::
Identity
)
{
GTEST_SKIP
();
}
// Skips Act tests if the Activation is an identity
// if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
// || processing_method == ProcessingMethod::CAST_DACT
// || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
// GTEST_SKIP();
// }
DACT_FUNC_SWITCH
(
Act_type
,
OP
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY
(
output_type
,
OutputType
,
runTestCase
<
InputType
,
OutputType
>
(
processing_method
,
matrix_size
,
rowwise
,
colwise
,
fill_case
,
q_opts
););););
}
TEST_P
(
FusedCastFloat8VectorwiseTestSuite
,
TestFusedCastFloat8Vectorwise
)
{
if
(
getDeviceComputeCapability
()
<
hopperComputeCapability
)
{
GTEST_SKIP
();
}
using
namespace
transformer_engine
;
using
namespace
test
;
const
ProcessingMethod
processing_method
=
std
::
get
<
0
>
(
GetParam
());
const
ActivationType
Act_type
=
std
::
get
<
1
>
(
GetParam
());
const
auto
matrix_size
=
std
::
get
<
2
>
(
GetParam
());
const
DType
input_type
=
std
::
get
<
3
>
(
GetParam
());
const
DType
output_type
=
std
::
get
<
4
>
(
GetParam
());
const
InputsFillCase
fill_case
=
std
::
get
<
5
>
(
GetParam
());
const
bool
colwise
=
std
::
get
<
6
>
(
GetParam
());
const
bool
rowwise
=
true
;
const
float
eps
=
std
::
get
<
7
>
(
GetParam
());
const
bool
force_pow_2
=
std
::
get
<
8
>
(
GetParam
());
QuantizationOptions
q_opts
;
q_opts
.
force_pow_2_scales
=
force_pow_2
;
q_opts
.
amax_epsilon
=
eps
;
q_opts
.
block_scaling_dim
=
1u
;
if
(
colwise
&&
matrix_size
.
size
()
<
2
)
{
// test_common Tensor initialization code does not
// handle this case.
GTEST_SKIP
();
}
// Skips non Act tests if the Activation type is not an identity
if
(
// (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
(
processing_method
==
ProcessingMethod
::
CAST_ONLY
)
&&
Act_type
!=
ActivationType
::
Identity
)
{
GTEST_SKIP
();
}
// Skips Act tests if the Activation is an identity
// if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
// || processing_method == ProcessingMethod::CAST_DACT
// || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
// GTEST_SKIP();
// }
DACT_FUNC_SWITCH
(
Act_type
,
OP
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY
(
output_type
,
OutputType
,
runTestCaseOneDimensionalBlocks
<
InputType
,
OutputType
>
(
processing_method
,
matrix_size
,
rowwise
,
colwise
,
fill_case
,
q_opts
););););
}
std
::
string
to_string
(
const
ProcessingMethod
method
)
{
switch
(
method
)
{
case
ProcessingMethod
::
CAST_ONLY
:
return
"CAST_ONLY"
;
// case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
// case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
// case ProcessingMethod::CAST_DACT: return "CAST_DACT";
// case ProcessingMethod::CAST_ACT: return "CAST_ACT";
default:
return
""
;
}
}
std
::
string
to_string
(
const
ActivationType
Act_type
)
{
switch
(
Act_type
)
{
case
ActivationType
::
Identity
:
return
"Identity"
;
// case ActivationType::GeLU: return "GeLU";
// case ActivationType::SiLU: return "SiLU";
// case ActivationType::ReLU: return "ReLU";
// case ActivationType::QGeLU: return "QGeLU";
// case ActivationType::SReLU: return "SReLU";
default:
return
""
;
}
}
INSTANTIATE_TEST_SUITE_P
(
OperatorTest
,
FusedCastFloat8BlockwiseTestSuite
,
::
testing
::
Combine
(
::
testing
::
ValuesIn
(
processing_methods
),
::
testing
::
ValuesIn
(
Activation_types
),
::
testing
::
ValuesIn
(
matrix_sizes
),
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
),
::
testing
::
Values
(
DType
::
kFloat8E4M3
,
DType
::
kFloat8E5M2
),
::
testing
::
ValuesIn
(
input_scenarios
),
::
testing
::
Values
(
true
,
false
),
::
testing
::
ValuesIn
(
amax_epsilons
),
::
testing
::
Values
(
true
,
false
)),
[](
const
testing
::
TestParamInfo
<
FusedCastFloat8BlockwiseTestSuite
::
ParamType
>&
info
)
{
std
::
string
name
=
to_string
(
std
::
get
<
0
>
(
info
.
param
))
+
"X"
+
to_string
(
std
::
get
<
1
>
(
info
.
param
));
const
auto
&
shape
=
std
::
get
<
2
>
(
info
.
param
);
for
(
const
auto
&
s
:
shape
)
{
name
+=
"X"
+
std
::
to_string
(
s
);
}
name
+=
"X"
+
test
::
typeName
(
std
::
get
<
3
>
(
info
.
param
))
+
"X"
+
test
::
typeName
(
std
::
get
<
4
>
(
info
.
param
))
+
"X"
+
test
::
caseName
(
std
::
get
<
5
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
6
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
7
>
(
info
.
param
)
!=
0.0
f
)
+
"X"
+
std
::
to_string
(
std
::
get
<
8
>
(
info
.
param
));
return
name
;
});
INSTANTIATE_TEST_SUITE_P
(
OperatorTest
,
FusedCastFloat8VectorwiseTestSuite
,
::
testing
::
Combine
(
::
testing
::
ValuesIn
(
processing_methods
),
::
testing
::
ValuesIn
(
Activation_types
),
::
testing
::
ValuesIn
(
matrix_sizes
),
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
),
::
testing
::
Values
(
DType
::
kFloat8E4M3
,
DType
::
kFloat8E5M2
),
::
testing
::
ValuesIn
(
input_scenarios
),
::
testing
::
Values
(
true
,
false
),
::
testing
::
ValuesIn
(
amax_epsilons
),
::
testing
::
Values
(
true
,
false
)),
[](
const
testing
::
TestParamInfo
<
FusedCastFloat8VectorwiseTestSuite
::
ParamType
>&
info
)
{
std
::
string
name
=
to_string
(
std
::
get
<
0
>
(
info
.
param
))
+
"X"
+
to_string
(
std
::
get
<
1
>
(
info
.
param
));
const
auto
&
shape
=
std
::
get
<
2
>
(
info
.
param
);
for
(
const
auto
&
s
:
shape
)
{
name
+=
"X"
+
std
::
to_string
(
s
);
}
name
+=
"X"
+
test
::
typeName
(
std
::
get
<
3
>
(
info
.
param
))
+
"X"
+
test
::
typeName
(
std
::
get
<
4
>
(
info
.
param
))
+
"X"
+
test
::
caseName
(
std
::
get
<
5
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
6
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
7
>
(
info
.
param
)
!=
0.0
f
)
+
"X"
+
std
::
to_string
(
std
::
get
<
8
>
(
info
.
param
));
return
name
;
});
tests/cpp/operator/test_normalization.cu
View file @
ab3e5a92
...
@@ -19,169 +19,16 @@
...
@@ -19,169 +19,16 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "../test_common.h"
#include "test_normalization.h"
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
using
namespace
test
;
using
namespace
test
;
namespace
{
namespace
{
enum
NormType
{
LayerNorm
,
RMSNorm
};
std
::
map
<
NormType
,
std
::
string
>
normToString
=
{
{
NormType
::
LayerNorm
,
"LayerNorm"
},
{
NormType
::
RMSNorm
,
"RmsNorm"
}
};
template
<
typename
InputType
>
void
compute_ref_stats
(
NormType
norm_type
,
const
InputType
*
data
,
float
*
mu
,
float
*
rsigma
,
const
size_t
N
,
const
size_t
H
,
const
double
epsilon
){
using
compute_t
=
float
;
compute_t
current
,
m
;
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
compute_t
sum
=
0
;
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
sum
+=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
}
if
(
norm_type
==
LayerNorm
){
mu
[
i
]
=
sum
/
H
;
m
=
mu
[
i
];
}
else
{
m
=
0
;}
compute_t
sum_sq
=
0
;
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
current
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
sum_sq
+=
(
current
-
m
)
*
(
current
-
m
);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma
[
i
]
=
1.0
/
sqrtf
((
sum_sq
/
H
)
+
epsilon
);
#else
rsigma
[
i
]
=
rsqrtf
((
sum_sq
/
H
)
+
epsilon
);
#endif
}
}
// For now, cudnn does static_cast<compute_t>(gamma + static_cast<input_t>(1.0))
// This will be changed in the future release
template
<
typename
InputType
>
inline
auto
compute_gamma
(
InputType
gamma
,
const
bool
zero_centered_gamma
,
const
bool
use_cudnn
){
using
compute_t
=
float
;
if
constexpr
(
std
::
is_same_v
<
InputType
,
fp8e5m2
>
||
std
::
is_same_v
<
InputType
,
fp8e4m3
>
){
compute_t
g
=
static_cast
<
compute_t
>
(
gamma
);
if
(
zero_centered_gamma
)
{
g
+=
static_cast
<
compute_t
>
(
1.
f
);
}
return
g
;
}
else
{
if
(
use_cudnn
){
compute_t
g
=
static_cast
<
compute_t
>
(
0.
f
);
#ifndef __HIP_PLATFORM_AMD__
InputType
gi
=
gamma
;
if
(
zero_centered_gamma
)
{
gi
=
gi
+
static_cast
<
InputType
>
(
1.
f
);
}
g
=
static_cast
<
compute_t
>
(
gi
);
#else
if
(
zero_centered_gamma
)
{
g
+=
static_cast
<
compute_t
>
(
1.
f
);
}
#endif
return
g
;
}
else
{
compute_t
g
=
static_cast
<
compute_t
>
(
gamma
);
if
(
zero_centered_gamma
)
{
g
+=
static_cast
<
compute_t
>
(
1.
f
);
}
return
g
;
}
}
}
template
<
typename
InputType
,
typename
OutputType
>
void
compute_ref_output
(
NormType
norm_type
,
const
InputType
*
data
,
const
InputType
*
gamma
,
const
InputType
*
beta
,
OutputType
*
output
,
const
float
*
mu
,
const
float
*
rsigma
,
const
size_t
N
,
const
size_t
H
,
float
*
amax
,
float
scale
,
const
bool
zero_centered_gamma
,
const
bool
use_cudnn
)
{
using
compute_t
=
float
;
compute_t
current_max
=
-
1e100
;
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
compute_t
current
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
compute_t
g
=
compute_gamma
(
gamma
[
j
],
zero_centered_gamma
,
use_cudnn
);
compute_t
tmp
;
if
(
norm_type
==
LayerNorm
)
{
tmp
=
(
current
-
mu
[
i
])
*
rsigma
[
i
]
*
g
+
static_cast
<
compute_t
>
(
beta
[
j
]);
}
else
{
// RMSNorm
tmp
=
current
*
rsigma
[
i
]
*
g
;
}
output
[
i
*
H
+
j
]
=
static_cast
<
OutputType
>
(
tmp
*
scale
);
current_max
=
fmaxf
(
current_max
,
fabsf
(
tmp
));
}
}
*
amax
=
current_max
;
}
template
<
typename
InputType
,
typename
OutputType
>
void
compute_ref_backward
(
const
NormType
norm_type
,
const
OutputType
*
output_grad
,
const
InputType
*
data
,
const
float
*
mu
,
const
float
*
rsigma
,
const
InputType
*
gamma
,
InputType
*
data_grad
,
InputType
*
gamma_grad
,
InputType
*
beta_grad
,
const
size_t
N
,
const
size_t
H
,
const
bool
zero_centered_gamma
,
const
bool
use_cudnn
)
{
using
compute_t
=
float
;
std
::
vector
<
compute_t
>
dgamma
(
H
,
0.
f
);
std
::
vector
<
compute_t
>
dbeta
(
H
,
0.
f
);
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
// Reductions
auto
local_mu
=
(
norm_type
==
LayerNorm
)
?
mu
[
i
]
:
0.
;
compute_t
mdy
=
0
,
mdyy
=
0
;
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
const
compute_t
x
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
const
compute_t
y
=
(
x
-
local_mu
)
*
rsigma
[
i
];
compute_t
g
=
compute_gamma
(
gamma
[
j
],
zero_centered_gamma
,
use_cudnn
);
const
compute_t
dz
=
static_cast
<
compute_t
>
(
output_grad
[
i
*
H
+
j
]);
const
compute_t
dy
=
g
*
dz
;
dgamma
[
j
]
+=
y
*
dz
;
if
(
norm_type
==
LayerNorm
)
{
dbeta
[
j
]
+=
dz
;
mdy
+=
dy
;
}
mdyy
+=
dy
*
y
;
}
mdy
/=
H
;
mdyy
/=
H
;
// Input grads
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
const
compute_t
x
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
const
compute_t
y
=
(
x
-
local_mu
)
*
rsigma
[
i
];
compute_t
g
=
compute_gamma
(
gamma
[
j
],
zero_centered_gamma
,
use_cudnn
);
const
compute_t
dz
=
static_cast
<
compute_t
>
(
output_grad
[
i
*
H
+
j
]);
const
compute_t
dy
=
g
*
dz
;
const
compute_t
dx
=
rsigma
[
i
]
*
(
dy
-
mdyy
*
y
-
mdy
);
data_grad
[
i
*
H
+
j
]
=
static_cast
<
InputType
>
(
dx
);
}
}
// Weight grads
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
gamma_grad
[
j
]
=
static_cast
<
InputType
>
(
dgamma
[
j
]);
if
(
norm_type
==
LayerNorm
)
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
beta_grad
[
j
]
=
static_cast
<
InputType
>
(
dbeta
[
j
]);
}
template
<
typename
InputType
,
typename
OutputType
>
template
<
typename
InputType
,
typename
OutputType
>
void
performTest
(
const
size_t
N
,
const
size_t
H
,
const
bool
zero_centered_gamma
,
void
performTest
(
const
size_t
N
,
const
size_t
H
,
const
bool
zero_centered_gamma
,
NormType
norm_type
,
bool
use_cudnn
)
{
NormType
norm_type
,
bool
use_cudnn
,
const
bool
zero_centered_gamma_in_weight_dtype
)
{
if
(
sizeof
(
InputType
)
<
sizeof
(
OutputType
))
{
if
(
sizeof
(
InputType
)
<
sizeof
(
OutputType
))
{
GTEST_SKIP
()
<<
"LN kernel does not support OutputType > InputType"
;
GTEST_SKIP
()
<<
"LN kernel does not support OutputType > InputType"
;
return
;
return
;
...
@@ -230,9 +77,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -230,9 +77,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
cudaDeviceProp
prop
;
cudaDeviceProp
prop
;
cudaGetDeviceProperties
(
&
prop
,
0
);
cudaGetDeviceProperties
(
&
prop
,
0
);
if
((
!
use_cudnn
||
!
zero_centered_gamma
)
&&
zero_centered_gamma_in_weight_dtype
)
{
// Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation
GTEST_SKIP
()
<<
"Zero-centered gamma in weight dtype is only supported with cuDNN backend"
;
}
if
(
use_cudnn
){
if
(
use_cudnn
){
nvte_enable_cudnn_norm_fwd
(
true
);
nvte_enable_cudnn_norm_fwd
(
true
);
nvte_enable_cudnn_norm_bwd
(
true
);
nvte_enable_cudnn_norm_bwd
(
true
);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if
(
zero_centered_gamma_in_weight_dtype
)
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
true
);
}
else
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
false
);
}
}
}
// Forward kernel
// Forward kernel
...
@@ -280,6 +140,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -280,6 +140,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
if
(
use_cudnn
){
if
(
use_cudnn
){
nvte_enable_cudnn_norm_fwd
(
false
);
nvte_enable_cudnn_norm_fwd
(
false
);
nvte_enable_cudnn_norm_bwd
(
false
);
nvte_enable_cudnn_norm_bwd
(
false
);
// Zero-centered gamma in weight dtype only supported by CuDNN backend currently
if
(
zero_centered_gamma_in_weight_dtype
)
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
false
);
}
}
}
// Reference implementations
// Reference implementations
...
@@ -300,14 +165,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -300,14 +165,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
&
ref_amax
,
&
ref_amax
,
ref_scale
,
ref_scale
,
zero_centered_gamma
,
zero_centered_gamma
,
use_cudnn
);
use_cudnn
,
zero_centered_gamma_in_weight_dtype
);
compute_ref_backward
(
norm_type
,
dz
.
rowwise_cpu_dptr
<
WeightType
>
(),
compute_ref_backward
(
norm_type
,
dz
.
rowwise_cpu_dptr
<
WeightType
>
(),
input
.
rowwise_cpu_dptr
<
InputType
>
(),
input
.
rowwise_cpu_dptr
<
InputType
>
(),
mu
.
rowwise_cpu_dptr
<
float
>
(),
rsigma
.
rowwise_cpu_dptr
<
float
>
(),
mu
.
rowwise_cpu_dptr
<
float
>
(),
rsigma
.
rowwise_cpu_dptr
<
float
>
(),
gamma
.
rowwise_cpu_dptr
<
WeightType
>
(),
gamma
.
rowwise_cpu_dptr
<
WeightType
>
(),
ref_dx
.
get
(),
ref_dgamma
.
get
(),
ref_dbeta
.
get
(),
ref_dx
.
get
(),
ref_dgamma
.
get
(),
ref_dbeta
.
get
(),
N
,
H
,
zero_centered_gamma
,
N
,
H
,
zero_centered_gamma
,
use_cudnn
);
use_cudnn
,
zero_centered_gamma_in_weight_dtype
);
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
auto
err
=
cudaGetLastError
();
auto
err
=
cudaGetLastError
();
...
@@ -352,6 +219,7 @@ NormType,
...
@@ -352,6 +219,7 @@ NormType,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
std
::
pair
<
size_t
,
size_t
>
,
std
::
pair
<
size_t
,
size_t
>
,
bool
,
bool
>>
{};
bool
>>
{};
TEST_P
(
NormTestSuite
,
TestNorm
)
{
TEST_P
(
NormTestSuite
,
TestNorm
)
{
...
@@ -364,10 +232,11 @@ TEST_P(NormTestSuite, TestNorm) {
...
@@ -364,10 +232,11 @@ TEST_P(NormTestSuite, TestNorm) {
const
DType
output_type
=
std
::
get
<
3
>
(
GetParam
());
const
DType
output_type
=
std
::
get
<
3
>
(
GetParam
());
const
auto
size
=
std
::
get
<
4
>
(
GetParam
());
const
auto
size
=
std
::
get
<
4
>
(
GetParam
());
const
bool
zero_centered_gamma
=
std
::
get
<
5
>
(
GetParam
());
const
bool
zero_centered_gamma
=
std
::
get
<
5
>
(
GetParam
());
const
bool
cudnn_zero_centered_gamm_in_weight_dtype
=
std
::
get
<
6
>
(
GetParam
());
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
(
output_type
,
OutputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL
(
output_type
,
OutputType
,
performTest
<
InputType
,
OutputType
>
(
size
.
first
,
size
.
second
,
zero_centered_gamma
,
norm_type
,
use_cudnn
);
performTest
<
InputType
,
OutputType
>
(
size
.
first
,
size
.
second
,
zero_centered_gamma
,
norm_type
,
use_cudnn
,
cudnn_zero_centered_gamm_in_weight_dtype
);
);
);
);
);
}
}
...
@@ -381,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(
...
@@ -381,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
),
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
),
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
,
DType
::
kFloat8E4M3
),
::
testing
::
Values
(
DType
::
kFloat32
,
DType
::
kBFloat16
,
DType
::
kFloat16
,
DType
::
kFloat8E4M3
),
::
testing
::
ValuesIn
(
test_cases
),
::
testing
::
ValuesIn
(
test_cases
),
::
testing
::
Values
(
false
,
true
),
::
testing
::
Values
(
false
,
true
)),
::
testing
::
Values
(
false
,
true
)),
[](
const
testing
::
TestParamInfo
<
NormTestSuite
::
ParamType
>&
info
)
{
[](
const
testing
::
TestParamInfo
<
NormTestSuite
::
ParamType
>&
info
)
{
auto
backend
=
std
::
get
<
0
>
(
info
.
param
)
==
false
?
"Te"
:
"Cudnn"
;
auto
backend
=
std
::
get
<
0
>
(
info
.
param
)
==
false
?
"Te"
:
"Cudnn"
;
...
@@ -391,6 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
...
@@ -391,6 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
test
::
typeName
(
std
::
get
<
3
>
(
info
.
param
))
+
"X"
+
test
::
typeName
(
std
::
get
<
3
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
).
first
)
+
"X"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
).
first
)
+
"X"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
).
second
)
+
"X"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
).
second
)
+
"X"
+
std
::
to_string
(
std
::
get
<
5
>
(
info
.
param
));
std
::
to_string
(
std
::
get
<
5
>
(
info
.
param
))
+
"X"
+
std
::
to_string
(
std
::
get
<
6
>
(
info
.
param
));
return
name
;
return
name
;
});
});
tests/cpp/operator/test_normalization.h
0 → 100644
View file @
ab3e5a92
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>

#include "../test_common.h"
namespace
test
{
namespace
{
// Normalization flavors exercised by these reference implementations:
// LayerNorm subtracts the per-row mean before scaling by 1/sigma,
// RMSNorm scales by the reciprocal root-mean-square only (no centering).
enum NormType {
  LayerNorm,
  RMSNorm
};

// Human-readable names used when building parameterized test case names.
// NOTE(review): RMSNorm intentionally maps to "RmsNorm" here — changing the
// spelling would rename every generated gtest case.
std::map<NormType, std::string> normToString = {
  {NormType::LayerNorm, "LayerNorm"},
  {NormType::RMSNorm, "RmsNorm"}
};
// CPU reference for the per-row normalization statistics.
// For each of the N rows of length H:
//   - LayerNorm: writes the row mean into mu[i] and the reciprocal standard
//     deviation (with epsilon) into rsigma[i].
//   - RMSNorm: mu is untouched; rsigma[i] holds the reciprocal RMS (second
//     moment about zero, with epsilon).
template <typename InputType>
void compute_ref_stats(NormType norm_type,
                       const InputType *data, float *mu, float *rsigma,
                       const size_t N, const size_t H, const double epsilon) {
  using compute_t = float;
  for (size_t i = 0; i < N; ++i) {
    const InputType *row = data + i * H;

    // First pass: accumulate the row sum.
    compute_t acc = 0;
    for (size_t j = 0; j < H; ++j) {
      acc += static_cast<compute_t>(row[j]);
    }

    // Value subtracted before squaring: the mean for LayerNorm, zero for
    // RMSNorm (which normalizes by the uncentered second moment).
    compute_t center;
    if (norm_type == LayerNorm) {
      mu[i] = acc / H;
      center = mu[i];
    } else {
      center = 0;
    }

    // Second pass: accumulate the (centered) sum of squares.
    compute_t sq_acc = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t diff = static_cast<compute_t>(row[j]) - center;
      sq_acc += diff * diff;
    }

#ifdef __HIP_PLATFORM_AMD__
    rsigma[i] = 1.0 / sqrtf((sq_acc / H) + epsilon);
#else
    rsigma[i] = rsqrtf((sq_acc / H) + epsilon);
#endif
  }
}
// CPU reference for resolving the effective scale ("gamma") applied during
// normalization, optionally with the zero-centered (+1) offset.
//
// When cudnn_zero_centered_gamma_in_weight_dtype is set (and the cuDNN
// backend is in use), the +1 offset is applied in the weight dtype before
// converting to compute precision, mirroring the cuDNN kernel's rounding.
// Otherwise the offset is applied in float compute precision.
template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma,
                          const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;

  // Zero-centered gamma in weight dtype is only supported in CuDNN backend
  // currently. Remove the use_cudnn check here when it is supported by both
  // backends.
  const bool zero_centered_gamma_in_weight_dtype =
      use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>) {
    // FP8 weights: always apply the offset in compute (float) precision.
    compute_t g = static_cast<compute_t>(gamma);
    if (zero_centered_gamma) {
      g += static_cast<compute_t>(1.f);
    }
    return g;
  } else {
    if (zero_centered_gamma_in_weight_dtype) {
      compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
      // Apply the offset in the weight dtype first, then convert, so the
      // reference matches the weight-dtype rounding of the cuDNN backend.
      InputType gi = gamma;
      if (zero_centered_gamma) {
        gi = gi + static_cast<InputType>(1.f);
      }
      g = static_cast<compute_t>(gi);
#else
      // Bug fix: the previous HIP path discarded gamma entirely and returned
      // only the 0/1 offset. Start from gamma (in float precision, since the
      // weight-dtype path is not taken on HIP), then add the offset.
      g = static_cast<compute_t>(gamma);
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
#endif
      return g;
    } else {
      // Default: convert to compute precision, then apply the offset.
      compute_t g = static_cast<compute_t>(gamma);
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
      return g;
    }
  }
}
// CPU reference for the LayerNorm/RMSNorm forward pass.
//
// For each element:
//   LayerNorm: out[i,j] = ((x - mu[i]) * rsigma[i] * g + beta[j]) * scale
//   RMSNorm:   out[i,j] = ( x          * rsigma[i] * g           ) * scale
// where g is the effective gamma from compute_gamma().
// If amax is non-null, it receives the maximum UNscaled |value| seen
// (i.e. before multiplication by scale), as needed for FP8 amax tracking.
template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type,
                        const InputType *data,
                        const InputType *gamma,
                        const InputType *beta,
                        OutputType *output,
                        const float *mu,
                        const float *rsigma,
                        const size_t N,
                        const size_t H,
                        float *amax,
                        float scale,
                        const bool zero_centered_gamma,
                        const bool use_cudnn,
                        const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // Bug fix: the previous "-1e100" literal is outside float's range, so the
  // double->float conversion was undefined behavior; in practice it yielded
  // -inf, which this preserves explicitly.
  compute_t current_max = -std::numeric_limits<compute_t>::infinity();
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      const compute_t current = static_cast<compute_t>(data[i * H + j]);
      const compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                        cudnn_zero_centered_gamma_in_weight_dtype);
      compute_t tmp;
      if (norm_type == LayerNorm) {
        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
      } else {  // RMSNorm: no centering, no bias
        tmp = current * rsigma[i] * g;
      }
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  if (amax) {
    *amax = current_max;
  }
}
// CPU reference for the LayerNorm/RMSNorm backward pass.
//
// Inputs:
//   output_grad (dz): upstream gradient, N x H
//   data (x), mu, rsigma, gamma: forward-pass inputs and saved statistics
//     (mu is only read when norm_type == LayerNorm)
// Outputs:
//   data_grad (dx): N x H; gamma_grad, beta_grad: length H
//     (beta_grad is only written for LayerNorm)
//
// Uses the standard reduction form: with y = (x - mu) * rsigma,
//   dx = rsigma * (dy - mean(dy * y) * y - mean(dy)),
// where the mean(dy) term is zero for RMSNorm.
template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType *output_grad,
                          const InputType *data,
                          const float *mu, const float *rsigma,
                          const InputType *gamma,
                          InputType *data_grad,
                          InputType *gamma_grad, InputType *beta_grad,
                          const size_t N, const size_t H,
                          const bool zero_centered_gamma,
                          const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // Weight gradients are accumulated across all rows in float precision and
  // only cast to InputType once at the end.
  std::vector<compute_t> dgamma(H, 0.f);
  std::vector<compute_t> dbeta(H, 0.f);
  for (size_t i = 0; i < N; ++i) {
    // Reductions
    // RMSNorm has no centering, so treat the mean as zero.
    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
    // mdy = mean of dy over the row (LayerNorm only); mdyy = mean of dy * y.
    compute_t mdy = 0, mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                  cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      dgamma[j] += y * dz;
      if (norm_type == LayerNorm) {
        dbeta[j] += dz;
        mdy += dy;  // only LayerNorm's dx depends on mean(dy)
      }
      mdyy += dy * y;
    }
    mdy /= H;
    mdyy /= H;
    // Input grads
    // Second pass recomputes y and dy per element rather than storing them.
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                  cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }
  // Weight grads
  for (size_t j = 0; j < H; ++j)
    gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  if (norm_type == LayerNorm)
    for (size_t j = 0; j < H; ++j)
      beta_grad[j] = static_cast<InputType>(dbeta[j]);
}
}
// namespace
}
// namespace test
tests/cpp/operator/test_normalization_mxfp8.cu
View file @
ab3e5a92
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <transformer_engine/normalization.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"
#include "../test_common.h"
#include "test_normalization.h"
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
using
namespace
test
;
using
namespace
test
;
...
@@ -27,16 +28,6 @@ namespace {
...
@@ -27,16 +28,6 @@ namespace {
using
fp8e8m0
=
byte
;
using
fp8e8m0
=
byte
;
enum
NormType
{
LayerNorm
,
RMSNorm
};
std
::
map
<
NormType
,
std
::
string
>
normToString
=
{
{
NormType
::
LayerNorm
,
"LayerNorm"
},
{
NormType
::
RMSNorm
,
"RMSNorm"
}
};
template
<
typename
InputType
,
typename
ScaleType
,
typename
OutputType
>
template
<
typename
InputType
,
typename
ScaleType
,
typename
OutputType
>
void
dequantize_1x_kernel
(
InputType
*
input_ptr
,
ScaleType
*
scale_ptr
,
OutputType
*
output_ptr
,
void
dequantize_1x_kernel
(
InputType
*
input_ptr
,
ScaleType
*
scale_ptr
,
OutputType
*
output_ptr
,
size_t
rows
,
size_t
cols
,
size_t
scaling_mode_x
,
size_t
scaling_mode_y
){
size_t
rows
,
size_t
cols
,
size_t
scaling_mode_x
,
size_t
scaling_mode_y
){
...
@@ -110,69 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
...
@@ -110,69 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
32
,
1
);
32
,
1
);
}
}
template
<
typename
InputType
>
void
compute_ref_stats
(
NormType
norm_type
,
const
InputType
*
data
,
float
*
mu
,
float
*
rsigma
,
const
size_t
N
,
const
size_t
H
,
const
double
epsilon
){
using
compute_t
=
float
;
#pragma omp parallel for proc_bind(spread)
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
compute_t
sum
=
0
;
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
sum
+=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
}
compute_t
m
;
if
(
norm_type
==
LayerNorm
){
mu
[
i
]
=
sum
/
H
;
m
=
mu
[
i
];
}
else
{
m
=
0
;}
compute_t
sum_sq
=
0
;
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
compute_t
current
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
sum_sq
+=
(
current
-
m
)
*
(
current
-
m
);
}
#ifdef __HIP_PLATFORM_AMD__
rsigma
[
i
]
=
1.0
/
sqrtf
((
sum_sq
/
H
)
+
epsilon
);
#else
rsigma
[
i
]
=
rsqrtf
((
sum_sq
/
H
)
+
epsilon
);
#endif
}
}
template
<
typename
InputType
,
typename
OutputType
>
void
compute_ref_output
(
NormType
norm_type
,
const
InputType
*
data
,
const
InputType
*
gamma
,
const
InputType
*
beta
,
const
float
*
mu
,
const
float
*
rsigma
,
const
size_t
N
,
const
size_t
H
,
OutputType
*
output
,
const
bool
zero_centered_gamma
){
using
compute_t
=
float
;
#pragma omp parallel for proc_bind(spread)
for
(
size_t
i
=
0
;
i
<
N
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
H
;
++
j
)
{
compute_t
current
=
static_cast
<
compute_t
>
(
data
[
i
*
H
+
j
]);
compute_t
g
=
static_cast
<
compute_t
>
(
gamma
[
j
]);
if
(
zero_centered_gamma
)
{
g
+=
1.0
;
}
compute_t
tmp
;
if
(
norm_type
==
LayerNorm
)
{
tmp
=
(
current
-
mu
[
i
])
*
rsigma
[
i
]
*
g
+
static_cast
<
compute_t
>
(
beta
[
j
]);
}
else
{
// RMSNorm
tmp
=
current
*
rsigma
[
i
]
*
g
;
}
output
[
i
*
H
+
j
]
=
tmp
;
}
}
}
template
<
typename
InputType
,
typename
OutputType
>
template
<
typename
InputType
,
typename
OutputType
>
void
performTest
(
const
size_t
N
,
const
size_t
H
,
const
bool
zero_centered_gamma
,
NormType
norm_type
,
bool
is_training
)
{
void
performTest
(
const
size_t
N
,
const
size_t
H
,
const
bool
zero_centered_gamma
,
NormType
norm_type
,
bool
is_training
,
const
bool
zero_centered_gamma_in_weight_dtype
)
{
cudaDeviceProp
prop
;
cudaDeviceProp
prop
;
cudaGetDeviceProperties
(
&
prop
,
0
);
cudaGetDeviceProperties
(
&
prop
,
0
);
...
@@ -199,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -199,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
fillUniform
(
&
gamma
);
fillUniform
(
&
gamma
);
fillUniform
(
&
beta
);
fillUniform
(
&
beta
);
if
(
zero_centered_gamma_in_weight_dtype
)
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
true
);
}
else
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
false
);
}
// Forward kernel
// Forward kernel
float
epsilon
=
1e-5
;
float
epsilon
=
1e-5
;
if
(
norm_type
==
NormType
::
LayerNorm
){
if
(
norm_type
==
NormType
::
LayerNorm
){
...
@@ -224,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -224,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
0
);
0
);
}
}
if
(
zero_centered_gamma_in_weight_dtype
)
{
nvte_enable_zero_centered_gamma_in_weight_dtype
(
false
);
}
Tensor
dequantized_output
(
"dequantized_output"
,
{
N
,
H
},
DType
::
kFloat32
,
true
,
true
);
Tensor
dequantized_output
(
"dequantized_output"
,
{
N
,
H
},
DType
::
kFloat32
,
true
,
true
);
dequantize_2x
<
OutputType
,
fp8e8m0
>
(
z
,
dequantized_output
,
is_training
);
dequantize_2x
<
OutputType
,
fp8e8m0
>
(
z
,
dequantized_output
,
is_training
);
...
@@ -250,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
...
@@ -250,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
compute_ref_output
(
norm_type
,
input
.
rowwise_cpu_dptr
<
InputType
>
(),
compute_ref_output
(
norm_type
,
input
.
rowwise_cpu_dptr
<
InputType
>
(),
gamma
.
rowwise_cpu_dptr
<
WeightType
>
(),
gamma
.
rowwise_cpu_dptr
<
WeightType
>
(),
beta
.
rowwise_cpu_dptr
<
WeightType
>
(),
beta
.
rowwise_cpu_dptr
<
WeightType
>
(),
ref_output
.
get
(),
ref_mu_ptr
,
ref_mu_ptr
,
ref_rsigma_ptr
,
ref_rsigma_ptr
,
N
,
H
,
N
,
H
,
ref_output
.
get
(),
nullptr
,
// amax
zero_centered_gamma
);
1.
f
,
// scale
zero_centered_gamma
,
true
,
// CuDNN is the only MXFP8 backend currently
zero_centered_gamma_in_weight_dtype
);
cudaDeviceSynchronize
();
cudaDeviceSynchronize
();
auto
err
=
cudaGetLastError
();
auto
err
=
cudaGetLastError
();
...
@@ -302,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
...
@@ -302,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam< std::tuple<NormType,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
transformer_engine
::
DType
,
std
::
pair
<
size_t
,
size_t
>
,
std
::
pair
<
size_t
,
size_t
>
,
bool
,
bool
>>
{};
bool
,
bool
,
bool
>>
{};
TEST_P
(
MxNormTestSuite
,
TestMxNorm
)
{
TEST_P
(
MxNormTestSuite
,
TestMxNorm
)
{
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
...
@@ -314,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
...
@@ -314,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
const
auto
size
=
std
::
get
<
3
>
(
GetParam
());
const
auto
size
=
std
::
get
<
3
>
(
GetParam
());
const
bool
zero_centered_gamma
=
std
::
get
<
4
>
(
GetParam
());
const
bool
zero_centered_gamma
=
std
::
get
<
4
>
(
GetParam
());
const
bool
is_training
=
std
::
get
<
5
>
(
GetParam
());
const
bool
is_training
=
std
::
get
<
5
>
(
GetParam
());
const
bool
zero_centered_gamma_in_weight_dtype
=
std
::
get
<
6
>
(
GetParam
());
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY
(
input_type
,
InputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY
(
output_type
,
OutputType
,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY
(
output_type
,
OutputType
,
performTest
<
InputType
,
OutputType
>
(
size
.
first
,
size
.
second
,
zero_centered_gamma
,
norm_type
,
is_training
);
performTest
<
InputType
,
OutputType
>
(
size
.
first
,
size
.
second
,
zero_centered_gamma
,
norm_type
,
is_training
,
zero_centered_gamma_in_weight_dtype
);
);
);
);
);
}
}
...
@@ -331,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
...
@@ -331,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
::
testing
::
Values
(
DType
::
kFloat8E5M2
,
DType
::
kFloat8E4M3
),
::
testing
::
Values
(
DType
::
kFloat8E5M2
,
DType
::
kFloat8E4M3
),
::
testing
::
ValuesIn
(
test_cases
),
::
testing
::
ValuesIn
(
test_cases
),
::
testing
::
Values
(
true
,
false
),
::
testing
::
Values
(
true
,
false
),
::
testing
::
Values
(
true
,
false
),
::
testing
::
Values
(
true
,
false
)),
::
testing
::
Values
(
true
,
false
)),
[](
const
testing
::
TestParamInfo
<
MxNormTestSuite
::
ParamType
>&
info
)
{
[](
const
testing
::
TestParamInfo
<
MxNormTestSuite
::
ParamType
>&
info
)
{
std
::
string
name
=
normToString
.
at
(
std
::
get
<
0
>
(
info
.
param
))
+
"_"
+
std
::
string
name
=
normToString
.
at
(
std
::
get
<
0
>
(
info
.
param
))
+
"_"
+
...
@@ -339,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
...
@@ -339,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
std
::
to_string
(
std
::
get
<
3
>
(
info
.
param
).
first
)
+
"X"
+
std
::
to_string
(
std
::
get
<
3
>
(
info
.
param
).
first
)
+
"X"
+
std
::
to_string
(
std
::
get
<
3
>
(
info
.
param
).
second
)
+
"X"
+
std
::
to_string
(
std
::
get
<
3
>
(
info
.
param
).
second
)
+
"X"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
))
+
"out"
+
std
::
to_string
(
std
::
get
<
4
>
(
info
.
param
))
+
"out"
+
std
::
to_string
(
int
(
std
::
get
<
5
>
(
info
.
param
))
+
1
)
+
"x"
;
std
::
to_string
(
int
(
std
::
get
<
5
>
(
info
.
param
))
+
1
)
+
"x"
+
std
::
to_string
(
std
::
get
<
6
>
(
info
.
param
));
return
name
;
return
name
;
});
});
tests/cpp/test_common.cu
View file @
ab3e5a92
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include <algorithm>
#include <algorithm>
#include <memory>
#include <memory>
#include <random>
#include <random>
#include <iostream>
#include <cassert>
#include <cassert>
#include <cmath>
#include <cmath>
#include <string>
#include <string>
...
@@ -111,8 +112,8 @@ struct scale_inv_meta {
...
@@ -111,8 +112,8 @@ struct scale_inv_meta {
size_t
type_size
;
size_t
type_size
;
};
};
NVTEShape
convertShape
(
const
std
::
vector
<
size_t
>&
s
hape
)
{
NVTEShape
convertShape
(
const
std
::
vector
<
size_t
>&
s
)
{
return
{
shape
.
data
(),
s
hape
.
size
()
}
;
return
nvte_make_
shape
(
s
.
data
(),
s
.
size
()
)
;
}
}
std
::
pair
<
scale_inv_meta
,
scale_inv_meta
>
get_scales
(
const
NVTEShape
&
shape
,
std
::
pair
<
scale_inv_meta
,
scale_inv_meta
>
get_scales
(
const
NVTEShape
&
shape
,
...
@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
...
@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
scale_inv_meta
ret_rowwise
,
ret_colwise
;
scale_inv_meta
ret_rowwise
,
ret_colwise
;
auto
block_alignment
=
std
::
vector
<
size_t
>
{
128ul
,
4ul
};
auto
block_alignment
=
std
::
vector
<
size_t
>
{
128ul
,
4ul
};
{
{
auto
alignment
=
block_alignment
[
0
];
auto
alignment
=
block_alignment
[
0
];
auto
scale_dim_0
=
DIVUP
(
DIVUP
(
first_dim
,
auto
scale_dim_0
=
DIVUP
(
DIVUP
(
first_dim
,
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
alignment
=
block_alignment
[
1
];
alignment
=
block_alignment
[
1
];
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
last_dim
,
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
last_dim
,
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
ret_rowwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
ret_rowwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
}
{
{
auto
alignment
=
block_alignment
[
1
];
auto
alignment
=
block_alignment
[
1
];
auto
scale_dim_0
=
DIVUP
(
DIVUP
(
first_dim
,
auto
scale_dim_0
=
DIVUP
(
DIVUP
(
first_dim
,
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
static_cast
<
size_t
>
(
32
)),
alignment
)
*
alignment
;
alignment
=
block_alignment
[
0
];
alignment
=
block_alignment
[
0
];
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
last_dim
,
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
last_dim
,
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
static_cast
<
size_t
>
(
1
)),
alignment
)
*
alignment
;
ret_colwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
ret_colwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
}
ret_rowwise
.
type
=
DType
::
kFloat8E8M0
;
ret_rowwise
.
type
=
DType
::
kFloat8E8M0
;
...
@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
...
@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
return
{
ret_rowwise
,
ret_colwise
};
return
{
ret_rowwise
,
ret_colwise
};
}
}
if
(
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
{
std
::
vector
<
size_t
>
shape_vec
;
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
;
++
i
)
{
shape_vec
.
push_back
(
shape
.
data
[
i
]);
}
size_t
first_dim
=
first_dimension
(
shape_vec
);
size_t
last_dim
=
last_dimension
(
shape_vec
);
scale_inv_meta
ret_rowwise
,
ret_colwise
;
{
auto
scale_dim_0
=
DIVUP
(
first_dim
,
static_cast
<
size_t
>
(
128
));
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
last_dim
,
static_cast
<
size_t
>
(
128
)),
4
)
*
4
;
ret_rowwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
{
auto
scale_dim_0
=
DIVUP
(
last_dim
,
static_cast
<
size_t
>
(
128
));
auto
scale_dim_1
=
DIVUP
(
DIVUP
(
first_dim
,
static_cast
<
size_t
>
(
128
)),
4
)
*
4
;
ret_colwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
ret_rowwise
.
type
=
DType
::
kFloat32
;
ret_colwise
.
type
=
DType
::
kFloat32
;
ret_rowwise
.
type_size
=
sizeof
(
float
);
ret_colwise
.
type_size
=
sizeof
(
float
);
return
{
ret_rowwise
,
ret_colwise
};
}
if
(
scaling_mode
==
NVTE_BLOCK_SCALING_1D
)
{
std
::
vector
<
size_t
>
shape_vec
;
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
;
++
i
)
{
shape_vec
.
push_back
(
shape
.
data
[
i
]);
}
size_t
first_dim
=
first_dimension
(
shape_vec
);
size_t
last_dim
=
last_dimension
(
shape_vec
);
scale_inv_meta
ret_rowwise
,
ret_colwise
;
{
auto
scale_dim_0
=
DIVUP
(
last_dim
,
static_cast
<
size_t
>
(
128
));
auto
scale_dim_1
=
DIVUP
(
first_dim
,
4
)
*
4
;
ret_rowwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
{
auto
scale_dim_0
=
DIVUP
(
first_dim
,
static_cast
<
size_t
>
(
128
));
auto
scale_dim_1
=
DIVUP
(
last_dim
,
4
)
*
4
;
ret_colwise
.
shape
=
{
scale_dim_0
,
scale_dim_1
};
}
ret_rowwise
.
type
=
DType
::
kFloat32
;
ret_colwise
.
type
=
DType
::
kFloat32
;
ret_rowwise
.
type_size
=
sizeof
(
float
);
ret_colwise
.
type_size
=
sizeof
(
float
);
return
{
ret_rowwise
,
ret_colwise
};
}
NVTE_ERROR
(
"Invalid scaling mode!"
);
NVTE_ERROR
(
"Invalid scaling mode!"
);
}
}
...
@@ -195,10 +240,10 @@ Tensor::Tensor(const std::string& name,
...
@@ -195,10 +240,10 @@ Tensor::Tensor(const std::string& name,
std
::
vector
<
size_t
>
normalized_shape_v
=
{
product
(
shape
,
0
,
shape
.
ndim
-
1
),
std
::
vector
<
size_t
>
normalized_shape_v
=
{
product
(
shape
,
0
,
shape
.
ndim
-
1
),
shape
.
data
[
shape
.
ndim
-
1
]};
shape
.
data
[
shape
.
ndim
-
1
]};
NVTEShape
normalized_shape
=
convertShape
(
normalized_shape_v
);
NVTEShape
normalized_shape
=
convertShape
(
normalized_shape_v
);
NVTEShape
columnwise_shape
{
nullptr
,
0
};
NVTEShape
columnwise_shape
=
{
};
std
::
vector
<
size_t
>
columnwise_shape_vec
;
std
::
vector
<
size_t
>
columnwise_shape_vec
;
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
||
scaling_mode
==
NVTE_BLOCK_SCALING_1D
||
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
{
// Transpose when tensor scaling
// Transpose when tensor scaling
columnwise_shape_vec
.
emplace_back
(
shape
.
data
[
shape
.
ndim
-
1
]);
columnwise_shape_vec
.
emplace_back
(
shape
.
data
[
shape
.
ndim
-
1
]);
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
-
1
;
++
i
)
{
...
@@ -212,8 +257,7 @@ Tensor::Tensor(const std::string& name,
...
@@ -212,8 +257,7 @@ Tensor::Tensor(const std::string& name,
}
}
if
(
columnwise
)
{
if
(
columnwise
)
{
columnwise_shape
.
data
=
columnwise_shape_vec
.
data
();
columnwise_shape
=
nvte_make_shape
(
columnwise_shape_vec
.
data
(),
columnwise_shape_vec
.
size
());
columnwise_shape
.
ndim
=
columnwise_shape_vec
.
size
();
}
}
tensor_
=
TensorWrapper
(
scaling_mode
);
tensor_
=
TensorWrapper
(
scaling_mode
);
...
@@ -259,25 +303,27 @@ Tensor::Tensor(const std::string& name,
...
@@ -259,25 +303,27 @@ Tensor::Tensor(const std::string& name,
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
sizeof
(
float
),
0
);
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
sizeof
(
float
),
0
);
}
}
}
else
{
}
else
{
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
normalized_shape
,
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
tensor_
.
scaling_mode
());
get_scales
(
normalized_shape
,
tensor_
.
scaling_mode
());
auto
rowwise_scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
auto
rowwise_scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
auto
columnwise_scale_size
=
product
(
colwise_scale_meta
.
shape
)
*
colwise_scale_meta
.
type_size
;
auto
columnwise_scale_size
=
product
(
colwise_scale_meta
.
shape
)
*
colwise_scale_meta
.
type_size
;
auto
scale_shape
=
rowwise_scale_meta
.
shape
;
auto
scale_shape
=
rowwise_scale_meta
.
shape
;
auto
columnwise_scale_shape
=
colwise_scale_meta
.
shape
;
auto
columnwise_scale_shape
=
colwise_scale_meta
.
shape
;
if
(
rowwise
)
{
if
(
rowwise
)
{
cudaMalloc
((
void
**
)
&
rowwise_scale_inv
,
rowwise_scale_size
);
// NOLINT(*)
cudaMalloc
((
void
**
)
&
rowwise_scale_inv
,
rowwise_scale_size
);
// NOLINT(*)
cudaMemset
(
rowwise_scale_inv
,
0
,
rowwise_scale_size
);
cudaMemset
(
rowwise_scale_inv
,
0
,
rowwise_scale_size
);
rowwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
rowwise_scale_size
);
rowwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
rowwise_scale_size
);
std
::
fill_n
(
rowwise_scale_inv_cpu_data_
.
get
(),
rowwise_scale_size
,
0
);
std
::
fill_n
(
rowwise_scale_inv_cpu_data_
.
get
(),
rowwise_scale_size
,
0
);
tensor_
.
set_rowwise_scale_inv
(
rowwise_scale_inv
,
DType
::
kFloat8E8M0
,
scale_shape
);
auto
scale_dtype
=
rowwise_scale_meta
.
type
;
tensor_
.
set_rowwise_scale_inv
(
rowwise_scale_inv
,
scale_dtype
,
scale_shape
);
}
}
if
(
columnwise
)
{
if
(
columnwise
)
{
cudaMalloc
((
void
**
)
&
columnwise_scale_inv
,
columnwise_scale_size
);
// NOLINT(*)
cudaMalloc
((
void
**
)
&
columnwise_scale_inv
,
columnwise_scale_size
);
// NOLINT(*)
cudaMemset
(
columnwise_scale_inv
,
0
,
columnwise_scale_size
);
cudaMemset
(
columnwise_scale_inv
,
0
,
columnwise_scale_size
);
columnwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
columnwise_scale_size
);
columnwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
columnwise_scale_size
);
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
columnwise_scale_size
,
0
);
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
columnwise_scale_size
,
0
);
tensor_
.
set_columnwise_scale_inv
(
columnwise_scale_inv
,
DType
::
kFloat8E8M0
,
columnwise_scale_shape
);
auto
scale_dtype
=
colwise_scale_meta
.
type
;
tensor_
.
set_columnwise_scale_inv
(
columnwise_scale_inv
,
scale_dtype
,
columnwise_scale_shape
);
}
}
}
}
}
}
...
@@ -311,7 +357,8 @@ void Tensor::to_cpu() const {
...
@@ -311,7 +357,8 @@ void Tensor::to_cpu() const {
sizeof
(
float
),
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
cudaMemcpyDeviceToHost
);
}
}
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
if
(
rowwise_
)
{
if
(
rowwise_
)
{
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
cudaMemcpy
(
rowwise_scale_inv_cpu_data_
.
get
(),
cudaMemcpy
(
rowwise_scale_inv_cpu_data_
.
get
(),
...
@@ -349,7 +396,8 @@ void Tensor::from_cpu() const {
...
@@ -349,7 +396,8 @@ void Tensor::from_cpu() const {
cudaMemcpy
(
tensor_
.
scale
(),
scale_cpu_data_
.
get
(),
sizeof
(
float
),
cudaMemcpy
(
tensor_
.
scale
(),
scale_cpu_data_
.
get
(),
sizeof
(
float
),
cudaMemcpyHostToDevice
);
cudaMemcpyHostToDevice
);
}
}
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
if
(
rowwise_
)
{
if
(
rowwise_
)
{
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
cudaMemcpy
(
tensor_
.
get_rowwise_scale_inv
().
data_ptr
,
cudaMemcpy
(
tensor_
.
get_rowwise_scale_inv
().
data_ptr
,
...
@@ -383,27 +431,29 @@ void Tensor::set_scale_inv(float scale_inv) {
...
@@ -383,27 +431,29 @@ void Tensor::set_scale_inv(float scale_inv) {
if
(
columnwise_
)
{
if
(
columnwise_
)
{
NVTE_CHECK
(
columnwise_scale_inv_cpu_data_
);
NVTE_CHECK
(
columnwise_scale_inv_cpu_data_
);
}
}
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
tensor_
.
shape
(),
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
tensor_
.
shape
(),
tensor_
.
scaling_mode
());
if
(
rowwise_
)
{
if
(
rowwise_
)
{
auto
num_scales
=
product
(
rowwise_scale_meta
.
shape
);
auto
num_scales
=
product
(
rowwise_scale_meta
.
shape
);
if
(
num_scales
==
1
){
if
(
num_scales
==
1
)
{
rowwise_cpu_scale_inv_ptr
<
float
>
()[
0
]
=
scale_inv
;
rowwise_cpu_scale_inv_ptr
<
float
>
()[
0
]
=
scale_inv
;
}
else
{
}
else
{
std
::
uniform_int_distribution
<
uint8_t
>
dis
(
0
,
127
);
std
::
uniform_int_distribution
<
uint8_t
>
dis
(
0
,
127
);
auto
*
scale_inv_ptr
=
rowwise_cpu_scale_inv_ptr
<
uint8_t
>
();
auto
*
scale_inv_ptr
=
rowwise_cpu_scale_inv_ptr
<
uint8_t
>
();
for
(
size_t
i
=
0
;
i
<
num_scales
;
i
++
){
for
(
size_t
i
=
0
;
i
<
num_scales
;
i
++
)
{
scale_inv_ptr
[
i
]
=
dis
(
gen_
);
scale_inv_ptr
[
i
]
=
dis
(
gen_
);
}
}
}
}
}
}
if
(
columnwise_
)
{
if
(
columnwise_
)
{
auto
num_scales
=
product
(
colwise_scale_meta
.
shape
);
auto
num_scales
=
product
(
colwise_scale_meta
.
shape
);
if
(
num_scales
==
1
){
if
(
num_scales
==
1
)
{
columnwise_cpu_scale_inv_ptr
<
float
>
()[
0
]
=
scale_inv
;
columnwise_cpu_scale_inv_ptr
<
float
>
()[
0
]
=
scale_inv
;
}
else
{
}
else
{
std
::
uniform_int_distribution
<
uint8_t
>
dis
(
0
,
127
);
std
::
uniform_int_distribution
<
uint8_t
>
dis
(
0
,
127
);
auto
*
scale_inv_ptr
=
columnwise_cpu_scale_inv_ptr
<
uint8_t
>
();
auto
*
scale_inv_ptr
=
columnwise_cpu_scale_inv_ptr
<
uint8_t
>
();
for
(
size_t
i
=
0
;
i
<
num_scales
;
i
++
){
for
(
size_t
i
=
0
;
i
<
num_scales
;
i
++
)
{
scale_inv_ptr
[
i
]
=
dis
(
gen_
);
scale_inv_ptr
[
i
]
=
dis
(
gen_
);
}
}
}
}
...
@@ -413,23 +463,20 @@ void Tensor::set_scale_inv(float scale_inv) {
...
@@ -413,23 +463,20 @@ void Tensor::set_scale_inv(float scale_inv) {
}
}
void
Tensor
::
shareFP8Meta
(
const
Tensor
&
other
)
{
void
Tensor
::
shareFP8Meta
(
const
Tensor
&
other
)
{
if
(
isFp8Type
(
dtype
())
&&
isFp8Type
(
other
.
dtype
()))
{
if
(
isFp8Type
(
dtype
())
&&
isFp8Type
(
other
.
dtype
()))
{
auto
new_tensor
=
TensorWrapper
(
other
.
tensor_
.
scaling_mode
());
auto
new_tensor
=
TensorWrapper
(
other
.
tensor_
.
scaling_mode
());
auto
my_rowwise_data
=
tensor_
.
get_rowwise_data
();
auto
my_rowwise_data
=
tensor_
.
get_rowwise_data
();
new_tensor
.
set_rowwise_data
(
my_rowwise_data
.
data_ptr
,
new_tensor
.
set_rowwise_data
(
my_rowwise_data
.
data_ptr
,
static_cast
<
DType
>
(
my_rowwise_data
.
dtype
),
static_cast
<
DType
>
(
my_rowwise_data
.
dtype
),
my_rowwise_data
.
shape
);
my_rowwise_data
.
shape
);
auto
my_columnwise_data
=
tensor_
.
get_columnwise_data
();
auto
my_columnwise_data
=
tensor_
.
get_columnwise_data
();
new_tensor
.
set_columnwise_data
(
my_columnwise_data
.
data_ptr
,
new_tensor
.
set_columnwise_data
(
my_columnwise_data
.
data_ptr
,
static_cast
<
DType
>
(
my_columnwise_data
.
dtype
),
static_cast
<
DType
>
(
my_columnwise_data
.
dtype
),
my_columnwise_data
.
shape
);
my_columnwise_data
.
shape
);
auto
other_amax
=
other
.
tensor_
.
get_amax
();
auto
other_amax
=
other
.
tensor_
.
get_amax
();
new_tensor
.
set_amax
(
other_amax
.
data_ptr
,
new_tensor
.
set_amax
(
other_amax
.
data_ptr
,
static_cast
<
DType
>
(
other_amax
.
dtype
),
static_cast
<
DType
>
(
other_amax
.
dtype
),
other_amax
.
shape
);
other_amax
.
shape
);
auto
other_scale
=
other
.
tensor_
.
get_scale
();
auto
other_scale
=
other
.
tensor_
.
get_scale
();
new_tensor
.
set_scale
(
other_scale
.
data_ptr
,
new_tensor
.
set_scale
(
other_scale
.
data_ptr
,
static_cast
<
DType
>
(
other_scale
.
dtype
),
static_cast
<
DType
>
(
other_scale
.
dtype
),
other_scale
.
shape
);
other_scale
.
shape
);
auto
other_row_scale_inv
=
other
.
tensor_
.
get_rowwise_scale_inv
();
auto
other_row_scale_inv
=
other
.
tensor_
.
get_rowwise_scale_inv
();
new_tensor
.
set_rowwise_scale_inv
(
other_row_scale_inv
.
data_ptr
,
new_tensor
.
set_rowwise_scale_inv
(
other_row_scale_inv
.
data_ptr
,
...
@@ -460,9 +507,7 @@ std::string to_string(const std::vector<T> &v) {
...
@@ -460,9 +507,7 @@ std::string to_string(const std::vector<T> &v) {
std
::
vector
<
size_t
>
unravel
(
const
size_t
i
,
const
NVTEShape
&
shape
)
{
std
::
vector
<
size_t
>
unravel
(
const
size_t
i
,
const
NVTEShape
&
shape
)
{
std
::
vector
<
size_t
>
ret
;
std
::
vector
<
size_t
>
ret
;
size_t
current_i
=
i
;
size_t
current_i
=
i
;
for
(
size_t
current
=
shape
.
ndim
-
1
;
for
(
size_t
current
=
shape
.
ndim
-
1
;
current
>
0
;
--
current
)
{
current
>
0
;
--
current
)
{
ret
.
push_back
(
current_i
%
shape
.
data
[
current
]);
ret
.
push_back
(
current_i
%
shape
.
data
[
current
]);
current_i
/=
shape
.
data
[
current
];
current_i
/=
shape
.
data
[
current
];
}
}
...
@@ -812,8 +857,7 @@ bool isFp8Type(DType type) {
...
@@ -812,8 +857,7 @@ bool isFp8Type(DType type) {
return
type
==
DType
::
kFloat8E4M3
||
type
==
DType
::
kFloat8E5M2
||
type
==
DType
::
kFloat8E8M0
;
return
type
==
DType
::
kFloat8E4M3
||
type
==
DType
::
kFloat8E5M2
||
type
==
DType
::
kFloat8E8M0
;
}
}
int32_t
getDeviceComputeCapability
()
int32_t
getDeviceComputeCapability
()
{
{
cudaDeviceProp
deviceProp
;
cudaDeviceProp
deviceProp
;
cudaGetDeviceProperties
(
&
deviceProp
,
0
);
cudaGetDeviceProperties
(
&
deviceProp
,
0
);
return
10
*
deviceProp
.
major
+
deviceProp
.
minor
;
return
10
*
deviceProp
.
major
+
deviceProp
.
minor
;
...
...
tests/cpp/test_common.h
View file @
ab3e5a92
...
@@ -121,7 +121,7 @@ class Tensor {
...
@@ -121,7 +121,7 @@ class Tensor {
const
bool
rowwise
=
true
,
const
bool
rowwise
=
true
,
const
bool
columnwise
=
false
,
const
bool
columnwise
=
false
,
const
NVTEScalingMode
&
mode
=
NVTE_DELAYED_TENSOR_SCALING
)
:
const
NVTEScalingMode
&
mode
=
NVTE_DELAYED_TENSOR_SCALING
)
:
Tensor
(
name
,
NVTES
hape
{
shape
.
data
(),
shape
.
size
()
}
,
type
,
rowwise
,
columnwise
,
mode
)
{}
Tensor
(
name
,
nvte_make_s
hape
(
shape
.
data
(),
shape
.
size
()
)
,
type
,
rowwise
,
columnwise
,
mode
)
{}
Tensor
()
{}
Tensor
()
{}
...
@@ -148,25 +148,19 @@ class Tensor {
...
@@ -148,25 +148,19 @@ class Tensor {
if
(
scale_inv
!=
nullptr
)
{
if
(
scale_inv
!=
nullptr
)
{
cudaFree
(
scale_inv
);
cudaFree
(
scale_inv
);
}
}
if
(
columnwise_data_ptr
!=
nullptr
){
if
(
columnwise_data_ptr
!=
nullptr
)
{
cudaFree
(
columnwise_data_ptr
);
cudaFree
(
columnwise_data_ptr
);
}
}
if
(
columnwise_scale_inv
!=
nullptr
){
if
(
columnwise_scale_inv
!=
nullptr
)
{
cudaFree
(
columnwise_scale_inv
);
cudaFree
(
columnwise_scale_inv
);
}
}
}
}
NVTETensor
data
()
const
noexcept
{
NVTETensor
data
()
const
noexcept
{
return
tensor_
.
data
();
}
return
tensor_
.
data
();
}
NVTEShape
rowwise_shape
()
const
noexcept
{
NVTEShape
rowwise_shape
()
const
noexcept
{
return
tensor_
.
get_rowwise_data
().
shape
;
}
return
tensor_
.
get_rowwise_data
().
shape
;
}
NVTEShape
columnwise_shape
()
const
noexcept
{
NVTEShape
columnwise_shape
()
const
noexcept
{
return
tensor_
.
get_columnwise_data
().
shape
;
}
return
tensor_
.
get_columnwise_data
().
shape
;
}
NVTEShape
rowwise_scale_inv_shape
()
const
{
NVTEShape
rowwise_scale_inv_shape
()
const
{
NVTE_CHECK
(
rowwise_
,
"Tensor does not have rowwise data!"
);
NVTE_CHECK
(
rowwise_
,
"Tensor does not have rowwise data!"
);
...
@@ -233,6 +227,8 @@ class Tensor {
...
@@ -233,6 +227,8 @@ class Tensor {
T
*
rowwise_cpu_scale_inv_ptr
(){
T
*
rowwise_cpu_scale_inv_ptr
(){
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
){
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
){
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
}
else
if
(
tensor_
.
scaling_mode
()
==
NVTE_BLOCK_SCALING_1D
||
tensor_
.
scaling_mode
()
==
NVTE_BLOCK_SCALING_2D
)
{
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
}
else
{
}
else
{
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kByte
,
"Invalid type!"
);
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kByte
,
"Invalid type!"
);
}
}
...
@@ -244,6 +240,8 @@ class Tensor {
...
@@ -244,6 +240,8 @@ class Tensor {
T
*
columnwise_cpu_scale_inv_ptr
(){
T
*
columnwise_cpu_scale_inv_ptr
(){
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
){
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
){
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
}
else
if
(
tensor_
.
scaling_mode
()
==
NVTE_BLOCK_SCALING_1D
||
tensor_
.
scaling_mode
()
==
NVTE_BLOCK_SCALING_2D
)
{
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kFloat32
,
"Invalid type!"
);
}
else
{
}
else
{
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kByte
,
"Invalid type!"
);
NVTE_CHECK
(
TypeInfo
<
T
>::
dtype
==
DType
::
kByte
,
"Invalid type!"
);
}
}
...
@@ -475,6 +473,7 @@ extern std::vector<DType> all_fp_types;
...
@@ -475,6 +473,7 @@ extern std::vector<DType> all_fp_types;
bool
isFp8Type
(
DType
type
);
bool
isFp8Type
(
DType
type
);
int32_t
getDeviceComputeCapability
();
int32_t
getDeviceComputeCapability
();
constexpr
int32_t
hopperComputeCapability
=
90
;
constexpr
int32_t
blackwellComputeCapability
=
100
;
constexpr
int32_t
blackwellComputeCapability
=
100
;
}
// namespace test
}
// namespace test
...
...
tests/jax/pytest.ini
View file @
ab3e5a92
...
@@ -25,3 +25,5 @@ filterwarnings=
...
@@ -25,3 +25,5 @@ filterwarnings=
ignore:jax.experimental.maps
and
.*
are
deprecated.*:DeprecationWarning
ignore:jax.experimental.maps
and
.*
are
deprecated.*:DeprecationWarning
ignore:The
host_callback
APIs
are
deprecated
.*:DeprecationWarning
ignore:The
host_callback
APIs
are
deprecated
.*:DeprecationWarning
ignore:Scan
loop
is
disabled
for
fused
ring
attention.*:UserWarning
ignore:Scan
loop
is
disabled
for
fused
ring
attention.*:UserWarning
ignore:jax.extend.ffi.register_ffi_target
is
deprecated
ignore:jax.extend.ffi.ffi_lowering
is
deprecated
tests/jax/test_custom_call_compute.py
View file @
ab3e5a92
...
@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
...
@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
ScaledTensor
,
ScaledTensor
,
ScalingMode
,
ScalingMode
,
QuantizerFactory
,
QuantizerFactory
,
Quantize
Axis
,
Quantize
Layout
,
)
)
from
transformer_engine.jax.quantize
import
helper
from
transformer_engine.jax.quantize
import
helper
from
transformer_engine.jax.activation
import
activation
from
transformer_engine.jax.activation
import
activation
...
@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
...
@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES
=
[(
256
,
128
),
(
128
,
256
)]
LN_CASES
=
[(
256
,
128
),
(
128
,
256
)]
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float32
]
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float32
]
is_fp8_supported
,
reason
=
helper
.
is_fp8_available
()
is_fp8_supported
,
reason
=
helper
.
is_fp8_available
()
is_mxfp8_supported
,
reason
=
helper
.
is_fp8_available
(
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
reason
=
helper
.
is_fp8_available
(
ScalingMode
.
MXFP8_1D_SCALING
)
supported_scaling_modes
=
[]
supported_scaling_modes
=
[]
""" Find supported scaling modes"""
""" Find supported scaling modes"""
if
is_fp8_supported
:
if
is_fp8_supported
:
supported_scaling_modes
.
append
(
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
)
supported_scaling_modes
.
append
(
ScalingMode
.
DELAYED_TENSOR_SCALING
)
if
is_mxfp8_supported
:
if
is_mxfp8_supported
:
supported_scaling_modes
.
append
(
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
)
supported_scaling_modes
.
append
(
ScalingMode
.
MXFP8_1D_SCALING
)
def
is_shape_supported_by_mxfp8
(
input_shape
):
def
is_shape_supported_by_mxfp8
(
input_shape
):
try
:
try
:
if
isinstance
(
input_shape
,
type
(
pytest
.
param
(
0
))):
if
isinstance
(
input_shape
,
type
(
pytest
.
param
(
0
))):
input_shape
=
input_shape
.
values
[
0
]
input_shape
=
input_shape
.
values
[
0
]
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
.
get_scale_shape_2x
(
input_shape
)
ScalingMode
.
MXFP8_1D_SCALING
.
get_scale_shape_2x
(
input_shape
)
return
True
return
True
except
:
except
:
# get_scale_shapes will raise an exception if the shape is not supported
# get_scale_shapes will raise an exception if the shape is not supported
...
@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
...
@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
def
assert_dequantized_scaled_tensor
(
a
:
ScaledTensor
,
b
:
jnp
.
ndarray
):
def
assert_dequantized_scaled_tensor
(
a
:
ScaledTensor
,
b
:
jnp
.
ndarray
):
if
isinstance
(
a
,
ScaledTensor1x
):
if
isinstance
(
a
,
ScaledTensor1x
):
if
a
.
layout
==
"T"
:
if
a
.
data_layout
==
"T"
:
b_transpose
=
jnp
.
transpose
(
b
,
(
-
1
,
*
range
(
b
.
ndim
-
1
)))
flatten_axis
=
a
.
data
.
ndim
-
a
.
flatten_axis
b_transpose
=
jnp
.
transpose
(
b
,
(
*
range
(
flatten_axis
,
b
.
ndim
),
*
range
(
flatten_axis
)))
assert_allclose
(
a
.
dequantize
(),
b_transpose
,
dtype
=
a
.
data
.
dtype
)
assert_allclose
(
a
.
dequantize
(),
b_transpose
,
dtype
=
a
.
data
.
dtype
)
else
:
else
:
assert_allclose
(
a
.
dequantize
(),
b
,
dtype
=
a
.
data
.
dtype
)
assert_allclose
(
a
.
dequantize
(),
b
,
dtype
=
a
.
data
.
dtype
)
...
@@ -141,7 +142,8 @@ class TestActivation:
...
@@ -141,7 +142,8 @@ class TestActivation:
def
test_act_grad
(
self
,
shape
,
activation_type
):
def
test_act_grad
(
self
,
shape
,
activation_type
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
x
=
jax
.
random
.
uniform
(
key
,
shape
,
jnp
.
float32
)
x
=
jax
.
random
.
uniform
(
key
,
shape
,
jnp
.
float32
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
x
=
jnp
.
expand_dims
(
x
,
axis
=-
2
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
2
)
value_n_grad_primitive_func
=
jit
(
value_n_grad_primitive_func
=
jit
(
value_and_grad
(
self
.
primitive_func
,
(
0
,)),
static_argnums
=
(
1
,)
value_and_grad
(
self
.
primitive_func
,
(
0
,)),
static_argnums
=
(
1
,)
...
@@ -159,7 +161,8 @@ class TestActivation:
...
@@ -159,7 +161,8 @@ class TestActivation:
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
def
test_act_grad_with_delayed_scaling_fp8
(
self
,
random_inputs
,
activation_type
,
output_type
):
def
test_act_grad_with_delayed_scaling_fp8
(
self
,
random_inputs
,
activation_type
,
output_type
):
x
=
random_inputs
x
=
random_inputs
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
x
=
jnp
.
expand_dims
(
x
,
axis
=-
2
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
2
)
self
.
activation_type
=
activation_type
self
.
activation_type
=
activation_type
value_n_grad_primitive_func
=
jit
(
value_n_grad_primitive_func
=
jit
(
...
@@ -167,9 +170,9 @@ class TestActivation:
...
@@ -167,9 +170,9 @@ class TestActivation:
)
)
quantizer
=
QuantizerFactory
.
create
(
quantizer
=
QuantizerFactory
.
create
(
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
q_dtype
=
output_type
,
q_dtype
=
output_type
,
q_
axis
=
Quantize
Axis
.
ROWWISE
,
q_
layout
=
Quantize
Layout
.
ROWWISE
,
)
)
prim_out
,
(
prim_grad
,)
=
value_n_grad_primitive_func
(
x
,
activation_type
,
quantizer
)
prim_out
,
(
prim_grad
,)
=
value_n_grad_primitive_func
(
x
,
activation_type
,
quantizer
)
...
@@ -182,19 +185,22 @@ class TestActivation:
...
@@ -182,19 +185,22 @@ class TestActivation:
@
pytest_parametrize_wrapper
(
"shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
ROWWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
ROWWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_act_forward_with_delayed_scaling_fp8
(
def
test_act_forward_with_delayed_scaling_fp8
(
self
,
random_inputs
,
activation_type
,
output_type
,
q_
axis
self
,
random_inputs
,
activation_type
,
output_type
,
q_
layout
):
):
x
=
random_inputs
x
=
random_inputs
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
x
=
jnp
.
expand_dims
(
x
,
axis
=-
2
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
2
)
self
.
activation_type
=
activation_type
self
.
activation_type
=
activation_type
te_quantizer
,
jax_quantizer
=
QuantizerFactory
.
create
(
te_quantizer
,
jax_quantizer
=
QuantizerFactory
.
create
(
n_quantizers
=
2
,
n_quantizers
=
2
,
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
q_dtype
=
output_type
,
q_dtype
=
output_type
,
q_
axis
=
q_axis
,
q_
layout
=
q_layout
,
)
)
te_output
=
tex
.
act_lu
(
x
,
activation_type
,
te_quantizer
)
te_output
=
tex
.
act_lu
(
x
,
activation_type
,
te_quantizer
)
...
@@ -203,19 +209,21 @@ class TestActivation:
...
@@ -203,19 +209,21 @@ class TestActivation:
assert_bitwise_scaled_tensors
(
te_output
,
jax_output
)
assert_bitwise_scaled_tensors
(
te_output
,
jax_output
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"shape"
,
[(
128
,
128
)])
@
pytest_parametrize_wrapper
(
"shape"
,
[(
2
,
64
,
1
,
256
)])
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"output_type"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
ROWWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
ROWWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_act_forward_with_block_scaling_fp8
(
def
test_act_forward_with_block_scaling_fp8
(
self
,
random_inputs
,
activation_type
,
output_type
,
q_
axis
self
,
random_inputs
,
activation_type
,
output_type
,
q_
layout
):
):
x
=
random_inputs
x
=
random_inputs
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
2
)
self
.
activation_type
=
activation_type
self
.
activation_type
=
activation_type
quantizer
=
QuantizerFactory
.
create
(
quantizer
=
QuantizerFactory
.
create
(
scaling_mode
=
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
,
q_dtype
=
output_type
,
q_
axis
=
q_axis
scaling_mode
=
ScalingMode
.
MXFP8_1D_SCALING
,
q_dtype
=
output_type
,
q_
layout
=
q_layout
)
)
output
=
tex
.
act_lu
(
x
,
activation_type
,
quantizer
)
output
=
tex
.
act_lu
(
x
,
activation_type
,
quantizer
)
...
@@ -324,9 +332,11 @@ class TestNorm:
...
@@ -324,9 +332,11 @@ class TestNorm:
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
# No Norm FWD E5M2 in TE backend
# No Norm FWD E5M2 in TE backend
@
pytest_parametrize_wrapper
(
"out_dtype"
,
[
jnp
.
float8_e4m3fn
])
@
pytest_parametrize_wrapper
(
"out_dtype"
,
[
jnp
.
float8_e4m3fn
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
ROWWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
ROWWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_norm_grad_with_delayed_scaling_fp8
(
def
test_norm_grad_with_delayed_scaling_fp8
(
self
,
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
out_dtype
,
q_
axis
self
,
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
out_dtype
,
q_
layout
):
):
"""
"""
Test transformer_engine.jax.layernorm.layernorm
Test transformer_engine.jax.layernorm.layernorm
...
@@ -335,7 +345,9 @@ class TestNorm:
...
@@ -335,7 +345,9 @@ class TestNorm:
pytest
.
skip
(
"RMSNorm and zero_centered_gamma is not supported!"
)
pytest
.
skip
(
"RMSNorm and zero_centered_gamma is not supported!"
)
quantizer
=
QuantizerFactory
.
create
(
quantizer
=
QuantizerFactory
.
create
(
scaling_mode
=
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
,
q_dtype
=
out_dtype
,
q_axis
=
q_axis
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
q_dtype
=
out_dtype
,
q_layout
=
q_layout
,
)
)
self
.
_test_norm_grad
(
self
.
_test_norm_grad
(
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
quantizer
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
quantizer
...
@@ -351,7 +363,7 @@ class TestNorm:
...
@@ -351,7 +363,7 @@ class TestNorm:
inp_dtype
,
inp_dtype
,
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_
axis
,
q_
layout
,
):
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
subkeys
=
jax
.
random
.
split
(
key
,
3
)
subkeys
=
jax
.
random
.
split
(
key
,
3
)
...
@@ -363,7 +375,7 @@ class TestNorm:
...
@@ -363,7 +375,7 @@ class TestNorm:
gamma
=
jnp
.
asarray
(
gamma
,
inp_dtype
)
gamma
=
jnp
.
asarray
(
gamma
,
inp_dtype
)
quantizer
,
ref_quantizer
=
QuantizerFactory
.
create
(
quantizer
,
ref_quantizer
=
QuantizerFactory
.
create
(
n_quantizers
=
2
,
scaling_mode
=
scaling_mode
,
q_dtype
=
out_dtype
,
q_
axis
=
q_axis
n_quantizers
=
2
,
scaling_mode
=
scaling_mode
,
q_dtype
=
out_dtype
,
q_
layout
=
q_layout
)
)
if
norm_type
==
"layernorm"
:
if
norm_type
==
"layernorm"
:
beta
=
jax
.
random
.
uniform
(
subkeys
[
2
],
(
hidden
,),
jnp
.
float32
,
-
1
,
1
)
beta
=
jax
.
random
.
uniform
(
subkeys
[
2
],
(
hidden
,),
jnp
.
float32
,
-
1
,
1
)
...
@@ -391,9 +403,11 @@ class TestNorm:
...
@@ -391,9 +403,11 @@ class TestNorm:
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
# No Norm FWD E5M2 in TE backend
# No Norm FWD E5M2 in TE backend
@
pytest_parametrize_wrapper
(
"out_dtype"
,
[
jnp
.
float8_e4m3fn
])
@
pytest_parametrize_wrapper
(
"out_dtype"
,
[
jnp
.
float8_e4m3fn
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
ROWWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
ROWWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_norm_forward_with_delayed_scaling_fp8
(
def
test_norm_forward_with_delayed_scaling_fp8
(
self
,
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
out_dtype
,
q_
axis
self
,
n
,
hidden
,
norm_type
,
zero_centered_gamma
,
epsilon
,
inp_dtype
,
out_dtype
,
q_
layout
):
):
if
norm_type
==
"rmsnorm"
and
zero_centered_gamma
is
True
:
if
norm_type
==
"rmsnorm"
and
zero_centered_gamma
is
True
:
pytest
.
skip
(
"RMSNorm and zero_centered_gamma is not supported!"
)
pytest
.
skip
(
"RMSNorm and zero_centered_gamma is not supported!"
)
...
@@ -406,8 +420,8 @@ class TestNorm:
...
@@ -406,8 +420,8 @@ class TestNorm:
epsilon
=
epsilon
,
epsilon
=
epsilon
,
inp_dtype
=
inp_dtype
,
inp_dtype
=
inp_dtype
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
q_
axis
=
q_axis
,
q_
layout
=
q_layout
,
)
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
...
@@ -423,8 +437,8 @@ class TestNorm:
...
@@ -423,8 +437,8 @@ class TestNorm:
epsilon
=
epsilon
,
epsilon
=
epsilon
,
inp_dtype
=
inp_dtype
,
inp_dtype
=
inp_dtype
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
,
scaling_mode
=
ScalingMode
.
MXFP8_1D_SCALING
,
q_
axis
=
Quantize
Axis
.
ROWWISE_COLWISE
,
q_
layout
=
Quantize
Layout
.
ROWWISE_COLWISE
,
)
)
...
@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
...
@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
}
}
ALL_QUANTIZE_TEST_SHAPES
=
[
ALL_QUANTIZE_TEST_SHAPES
=
[
(
128
,
128
),
(
32
,
64
),
(
4
,
256
,
51
2
),
(
2
,
64
,
3
2
),
]
]
QUANTIZE_TEST_SHAPES
=
{
QUANTIZE_TEST_SHAPES
=
{
"L0"
:
[
"L0"
:
[
(
256
,
128
),
(
32
,
256
,
128
),
(
64
,
16
,
2
,
256
),
(
64
,
32
,
3
2
,
256
),
],
],
"L2"
:
ALL_QUANTIZE_TEST_SHAPES
,
"L2"
:
ALL_QUANTIZE_TEST_SHAPES
,
}
}
...
@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
...
@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_QUANTIZE_TEST_SHAPES
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_QUANTIZE_TEST_SHAPES
)
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"flatten_axis"
,
[
-
1
,
-
2
])
@
pytest_parametrize_wrapper
(
@
pytest_parametrize_wrapper
(
"q_
axis
"
,
[
Quantize
Axis
.
ROWWISE
,
Quantize
Axis
.
COLWISE
,
Quantize
Axis
.
ROWWISE_COLWISE
]
"q_
layout
"
,
[
Quantize
Layout
.
ROWWISE
,
Quantize
Layout
.
COLWISE
,
Quantize
Layout
.
ROWWISE_COLWISE
]
)
)
class
TestQuantize
:
class
TestQuantize
:
"""
"""
Purely quantization related tests that will always test on a wider set of types and shapes
Purely quantization related tests that will always test on a wider set of types and shapes
"""
"""
def
test_qdq
(
self
,
in_dtype
,
input_shape
,
q_dtype
,
scaling_mode
,
q_axis
):
def
test_qdq
(
self
,
in_dtype
,
input_shape
,
q_dtype
,
scaling_mode
,
q_
layout
,
flatten_
axis
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
# Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
# Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
quantizer
=
QuantizerFactory
.
create
(
quantizer
=
QuantizerFactory
.
create
(
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
,
q_dtype
=
q_dtype
,
q_dtype
=
q_dtype
,
q_
axis
=
q_axis
,
q_
layout
=
q_layout
,
)
)
# Adding dimension to test if padding is done correctly when flatten 3D to 2D
if
flatten_axis
==
-
2
:
input_shape
=
input_shape
[:
-
1
]
+
(
2
,)
+
input_shape
[
-
1
:]
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
else
1
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
else
1
for
_
in
range
(
n_iterations
):
for
_
in
range
(
n_iterations
):
x
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
x
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
scaled_tensor
=
quantizer
.
quantize
(
x
)
scaled_tensor
=
quantizer
.
quantize
(
x
,
flatten_axis
=
flatten_axis
)
assert_dequantized_scaled_tensor
(
scaled_tensor
,
x
)
assert_dequantized_scaled_tensor
(
scaled_tensor
,
x
)
def
test_quantize_bitwise
(
self
,
in_dtype
,
input_shape
,
q_dtype
,
scaling_mode
,
q_axis
):
def
test_quantize_bitwise
(
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
and
not
is_shape_supported_by_mxfp8
(
self
,
in_dtype
,
input_shape
,
q_dtype
,
scaling_mode
,
q_layout
,
flatten_axis
input_shape
):
):
pytest
.
skip
(
f
"Input shape
{
input_shape
}
is not supported by MXFP8"
)
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
if
flatten_axis
==
-
2
:
input_shape
=
input_shape
[:
-
1
]
+
(
2
,)
+
input_shape
[
-
1
:]
input
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
input
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
te_quantizer
,
jax_quantizer
=
QuantizerFactory
.
create
(
te_quantizer
,
jax_quantizer
=
QuantizerFactory
.
create
(
n_quantizers
=
2
,
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
axis
=
q_axis
n_quantizers
=
2
,
q_dtype
=
q_dtype
,
scaling_mode
=
scaling_mode
,
q_
layout
=
q_layout
)
)
jax_output
=
_jax_quantize
(
input
,
quantizer
=
jax_quantizer
)
jax_output
=
_jax_quantize
(
input
,
quantizer
=
jax_quantizer
,
flatten_axis
=
flatten_axis
)
te_output
=
tex
.
quantize
(
input
,
quantizer
=
te_quantizer
)
te_output
=
tex
.
quantize
(
input
,
quantizer
=
te_quantizer
,
flatten_axis
=
flatten_axis
)
assert_bitwise_scaled_tensors
(
jax
_output
,
te
_output
)
assert_bitwise_scaled_tensors
(
te
_output
,
jax
_output
)
@
pytest_parametrize_wrapper
(
"in_dtype"
,
QUANTIZATION_INPUT_DTYPE
)
@
pytest_parametrize_wrapper
(
"in_dtype"
,
QUANTIZATION_INPUT_DTYPE
)
...
@@ -508,10 +526,14 @@ class TestFusedQuantize:
...
@@ -508,10 +526,14 @@ class TestFusedQuantize:
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
QUANTIZE_TEST_SHAPES
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
QUANTIZE_TEST_SHAPES
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
ROWWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
def
test_quantize_dbias
(
self
,
in_dtype
,
input_shape
,
out_dtype
,
scaling_mode
,
q_axis
):
"q_layout"
,
[
QuantizeLayout
.
ROWWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
transpose_axis
=
-
1
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
and
not
is_shape_supported_by_mxfp8
(
@
pytest_parametrize_wrapper
(
"flatten_axis"
,
[
-
1
,
-
2
])
def
test_quantize_dbias
(
self
,
in_dtype
,
input_shape
,
out_dtype
,
scaling_mode
,
q_layout
,
flatten_axis
):
if
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
and
not
is_shape_supported_by_mxfp8
(
input_shape
input_shape
):
):
pytest
.
skip
(
f
"Input shape
{
input_shape
}
is not supported by MXFP8"
)
pytest
.
skip
(
f
"Input shape
{
input_shape
}
is not supported by MXFP8"
)
...
@@ -520,35 +542,37 @@ class TestFusedQuantize:
...
@@ -520,35 +542,37 @@ class TestFusedQuantize:
input
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
input
=
jax
.
random
.
uniform
(
key
,
input_shape
,
in_dtype
)
jax_quantizer
,
te_quantizer
=
QuantizerFactory
.
create
(
jax_quantizer
,
te_quantizer
=
QuantizerFactory
.
create
(
n_quantizers
=
2
,
q_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
q_
axis
=
q_axis
n_quantizers
=
2
,
q_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
q_
layout
=
q_layout
)
)
te_output
,
te_dbias
=
jit
(
lambda
input
:
tex
.
quantize_dbias
(
input
,
quantizer
=
te_quantizer
))(
te_output
,
te_dbias
=
jit
(
input
lambda
input
:
tex
.
quantize_dbias
(
input
,
quantizer
=
te_quantizer
,
flatten_axis
=
flatten_axis
)
)
)(
input
)
jax_output
,
jax_dbias
=
jit
(
jax_output
,
jax_dbias
=
jit
(
lambda
input
:
_jax_quantize_dbias
(
lambda
input
:
_jax_quantize_dbias
(
input
,
input
,
quantizer
=
jax_quantizer
,
flatten_axis
=
flatten_axis
quantizer
=
jax_quantizer
,
)
)
)(
input
)
)(
input
)
assert_bitwise_scaled_tensors
(
jax
_output
,
te
_output
)
assert_bitwise_scaled_tensors
(
te
_output
,
jax
_output
)
assert_allclose
(
jax
_dbias
,
te
_dbias
)
assert_allclose
(
te
_dbias
,
jax
_dbias
)
def
_test_quantize_dact_dbias
(
def
_test_quantize_dact_dbias
(
self
,
in_dtype
,
input_shape
,
out_dtype
,
scaling_mode
,
activation_type
,
is_dbias
,
q_
axis
self
,
in_dtype
,
input_shape
,
out_dtype
,
scaling_mode
,
activation_type
,
is_dbias
,
q_
layout
):
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
subkeys
=
jax
.
random
.
split
(
key
,
2
)
subkeys
=
jax
.
random
.
split
(
key
,
2
)
x
=
jax
.
random
.
uniform
(
subkeys
[
0
],
input_shape
,
in_dtype
,
-
1
,
1
)
x
=
jax
.
random
.
uniform
(
subkeys
[
0
],
input_shape
,
in_dtype
,
-
1
,
1
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
1
)
x
=
jnp
.
expand_dims
(
x
,
axis
=-
2
)
x
=
jnp
.
repeat
(
x
,
len
(
activation_type
),
axis
=-
2
)
dz
=
jax
.
random
.
uniform
(
subkeys
[
1
],
input_shape
,
in_dtype
,
-
1
,
1
)
dz
=
jax
.
random
.
uniform
(
subkeys
[
1
],
input_shape
,
in_dtype
,
-
1
,
1
)
jax_quantizer
,
te_quantizer
=
QuantizerFactory
.
create
(
jax_quantizer
,
te_quantizer
=
QuantizerFactory
.
create
(
n_quantizers
=
2
,
q_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
q_
axis
=
q_axis
n_quantizers
=
2
,
q_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
q_
layout
=
q_layout
)
)
is_casted_output
=
te_quantizer
is
not
None
is_casted_output
=
te_quantizer
is
not
None
...
@@ -573,12 +597,12 @@ class TestFusedQuantize:
...
@@ -573,12 +597,12 @@ class TestFusedQuantize:
)(
dz
,
x
)
)(
dz
,
x
)
if
is_casted_output
:
if
is_casted_output
:
assert_bitwise_scaled_tensors
(
jax
_output
,
te
_output
)
assert_bitwise_scaled_tensors
(
te
_output
,
jax
_output
)
else
:
else
:
assert_allclose
(
jax
_output
,
te
_output
)
assert_allclose
(
te
_output
,
jax
_output
)
if
is_dbias
:
if
is_dbias
:
assert_allclose
(
jax
_dbias
,
te
_dbias
)
assert_allclose
(
te
_dbias
,
jax
_dbias
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
ACTIVATION_TYPES
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_ACTIVATION_SHAPES
)
...
@@ -594,10 +618,10 @@ class TestFusedQuantize:
...
@@ -594,10 +618,10 @@ class TestFusedQuantize:
in_dtype
=
in_dtype
,
in_dtype
=
in_dtype
,
input_shape
=
input_shape
,
input_shape
=
input_shape
,
out_dtype
=
in_dtype
,
out_dtype
=
in_dtype
,
scaling_mode
=
ScalingMode
.
NVTE_
NO_SCALING
,
scaling_mode
=
ScalingMode
.
NO_SCALING
,
activation_type
=
activation_type
,
activation_type
=
activation_type
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
q_
axis
=
Quantize
Axis
.
ROWWISE
,
q_
layout
=
Quantize
Layout
.
ROWWISE
,
)
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
...
@@ -605,18 +629,20 @@ class TestFusedQuantize:
...
@@ -605,18 +629,20 @@ class TestFusedQuantize:
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"input_shape"
,
ALL_ACTIVATION_SHAPES
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"is_dbias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"is_dbias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
COLWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
COLWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_quantize_dact_dbias_delayed_scaling
(
def
test_quantize_dact_dbias_delayed_scaling
(
self
,
in_dtype
,
input_shape
,
out_dtype
,
activation_type
,
is_dbias
,
q_
axis
self
,
in_dtype
,
input_shape
,
out_dtype
,
activation_type
,
is_dbias
,
q_
layout
):
):
self
.
_test_quantize_dact_dbias
(
self
.
_test_quantize_dact_dbias
(
in_dtype
=
in_dtype
,
in_dtype
=
in_dtype
,
input_shape
=
input_shape
,
input_shape
=
input_shape
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
,
scaling_mode
=
ScalingMode
.
DELAYED_TENSOR_SCALING
,
activation_type
=
activation_type
,
activation_type
=
activation_type
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
q_
axis
=
q_axis
,
q_
layout
=
q_layout
,
)
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_mxfp8_supported
,
reason
=
reason
)
...
@@ -626,9 +652,11 @@ class TestFusedQuantize:
...
@@ -626,9 +652,11 @@ class TestFusedQuantize:
)
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"out_dtype"
,
QUANTIZE_OUTPUT_DTYPES
)
@
pytest_parametrize_wrapper
(
"is_dbias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"is_dbias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"q_axis"
,
[
QuantizeAxis
.
COLWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
])
@
pytest_parametrize_wrapper
(
"q_layout"
,
[
QuantizeLayout
.
COLWISE
,
QuantizeLayout
.
ROWWISE_COLWISE
]
)
def
test_quantize_dact_dbias_mxfp8_scaling
(
def
test_quantize_dact_dbias_mxfp8_scaling
(
self
,
in_dtype
,
input_shape
,
out_dtype
,
activation_type
,
is_dbias
,
q_
axis
self
,
in_dtype
,
input_shape
,
out_dtype
,
activation_type
,
is_dbias
,
q_
layout
):
):
if
reduce
(
operator
.
mul
,
input_shape
[:
-
1
])
%
128
!=
0
or
input_shape
[
-
1
]
%
128
!=
0
:
if
reduce
(
operator
.
mul
,
input_shape
[:
-
1
])
%
128
!=
0
or
input_shape
[
-
1
]
%
128
!=
0
:
# TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
# TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
...
@@ -642,78 +670,78 @@ class TestFusedQuantize:
...
@@ -642,78 +670,78 @@ class TestFusedQuantize:
in_dtype
=
in_dtype
,
in_dtype
=
in_dtype
,
input_shape
=
input_shape
,
input_shape
=
input_shape
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
,
scaling_mode
=
ScalingMode
.
MXFP8_1D_SCALING
,
activation_type
=
activation_type
,
activation_type
=
activation_type
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
q_
axis
=
q_axis
,
q_
layout
=
q_layout
,
)
)
class
TestDense
:
class
TestDense
:
def
_ref_gemm_with_jnp_dot
(
self
,
a
,
b
,
layout
):
def
_ref_gemm_with_jnp_dot
(
self
,
a
,
b
,
data_
layout
):
if
layout
[
0
]
==
"T"
:
if
data_
layout
[
0
]
==
"T"
:
a
=
jnp
.
swapaxes
(
a
,
-
1
,
-
2
)
a
=
jnp
.
swapaxes
(
a
,
-
1
,
-
2
)
if
layout
[
1
]
==
"T"
:
if
data_
layout
[
1
]
==
"T"
:
b
=
jnp
.
swapaxes
(
b
,
-
1
,
-
2
)
b
=
jnp
.
swapaxes
(
b
,
-
1
,
-
2
)
return
jnp
.
dot
(
a
,
b
)
return
jnp
.
dot
(
a
,
b
)
def
_generate_gemm_input
(
self
,
m
,
n
,
k
,
layout
):
def
_generate_gemm_input
(
self
,
m
,
n
,
k
,
data_
layout
):
key
=
jax
.
random
.
PRNGKey
(
0
)
key
=
jax
.
random
.
PRNGKey
(
0
)
subkeys
=
jax
.
random
.
split
(
key
,
2
)
subkeys
=
jax
.
random
.
split
(
key
,
2
)
x
=
jax
.
random
.
uniform
(
x
=
jax
.
random
.
uniform
(
subkeys
[
0
],
subkeys
[
0
],
(
m
if
layout
[
0
]
==
"N"
else
k
,
k
if
layout
[
0
]
==
"N"
else
m
),
(
m
if
data_
layout
[
0
]
==
"N"
else
k
,
k
if
data_
layout
[
0
]
==
"N"
else
m
),
dtype
=
jnp
.
bfloat16
,
dtype
=
jnp
.
bfloat16
,
)
/
jnp
.
sqrt
(
k
)
)
/
jnp
.
sqrt
(
k
)
w
=
jax
.
random
.
uniform
(
w
=
jax
.
random
.
uniform
(
subkeys
[
1
],
subkeys
[
1
],
(
k
if
layout
[
1
]
==
"N"
else
n
,
n
if
layout
[
1
]
==
"N"
else
k
),
(
k
if
data_
layout
[
1
]
==
"N"
else
n
,
n
if
data_
layout
[
1
]
==
"N"
else
k
),
dtype
=
jnp
.
bfloat16
,
dtype
=
jnp
.
bfloat16
,
)
/
jnp
.
sqrt
(
n
)
)
/
jnp
.
sqrt
(
n
)
lhs_contracting_dim
=
(
1
,)
if
layout
[
0
]
==
"N"
else
(
0
,)
lhs_contracting_dim
=
(
1
,)
if
data_
layout
[
0
]
==
"N"
else
(
0
,)
rhs_contracting_dim
=
(
0
,)
if
layout
[
1
]
==
"N"
else
(
1
,)
rhs_contracting_dim
=
(
0
,)
if
data_
layout
[
1
]
==
"N"
else
(
1
,)
contracting_dims
=
(
lhs_contracting_dim
,
rhs_contracting_dim
)
contracting_dims
=
(
lhs_contracting_dim
,
rhs_contracting_dim
)
return
(
x
,
w
,
contracting_dims
)
return
(
x
,
w
,
contracting_dims
)
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
512
,
128
,
256
)])
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
64
,
32
,
64
)])
@
pytest_parametrize_wrapper
(
"layout"
,
[
"TN"
,
"NT"
,
"NN"
,
"TT"
])
@
pytest_parametrize_wrapper
(
"
data_
layout"
,
[
"TN"
,
"NT"
,
"NN"
,
"TT"
])
def
test_gemm_bf16
(
self
,
m
,
n
,
k
,
layout
):
def
test_gemm_bf16
(
self
,
m
,
n
,
k
,
data_
layout
):
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
layout
)
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
data_
layout
)
primitive_out
=
tex
.
gemm
(
x
,
w
,
contracting_dims
)
primitive_out
=
tex
.
gemm
(
x
,
w
,
contracting_dims
)
ref_out
=
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
layout
)
ref_out
=
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
data_
layout
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
jnp
.
bfloat16
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
512
,
128
,
256
)])
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
64
,
32
,
64
)])
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"layout"
,
[
"TN"
,
"NT"
,
"NN"
,
"TT"
])
@
pytest_parametrize_wrapper
(
"
data_
layout"
,
[
"TN"
,
"NT"
,
"NN"
,
"TT"
])
def
test_gemm_fp8
(
self
,
m
,
n
,
k
,
q_dtype
,
scaling_mode
,
layout
):
def
test_gemm_fp8
(
self
,
m
,
n
,
k
,
q_dtype
,
scaling_mode
,
data_
layout
):
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
layout
)
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
data_
layout
)
quantizer_set
=
QuantizerFactory
.
create_set
(
quantizer_set
=
QuantizerFactory
.
create_set
(
scaling_mode
=
scaling_mode
,
fwd_dtype
=
q_dtype
,
bwd_dtype
=
q_dtype
,
is_2x2x
=
False
scaling_mode
=
scaling_mode
,
fwd_dtype
=
q_dtype
,
bwd_dtype
=
q_dtype
,
is_2x2x
=
False
)
)
primitive_out
=
tex
.
gemm
(
primitive_out
=
tex
.
gemm
(
x
,
w
,
contracting_dims
=
contracting_dims
,
quantizer_set
=
quantizer_set
x
,
w
,
contracting_dims
=
contracting_dims
,
quantizer_set
=
quantizer_set
)
)
ref_out
=
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
layout
)
ref_out
=
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
data_
layout
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
q_dtype
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
q_dtype
)
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
512
,
128
,
256
)])
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
64
,
32
,
64
)])
def
test_dense_grad_bf16
(
self
,
m
,
n
,
k
):
def
test_dense_grad_bf16
(
self
,
m
,
n
,
k
):
layout
=
"NN"
data_
layout
=
"NN"
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
layout
)
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
data_
layout
)
def
primitive_func
(
x
,
w
,
contracting_dims
):
def
primitive_func
(
x
,
w
,
contracting_dims
):
primitive_out
=
dense
(
x
,
w
,
contracting_dims
=
contracting_dims
)
primitive_out
=
dense
(
x
,
w
,
contracting_dims
=
contracting_dims
)
return
jnp
.
mean
(
primitive_out
)
return
jnp
.
mean
(
primitive_out
)
def
ref_func
(
x
,
w
,
layout
):
def
ref_func
(
x
,
w
,
data_
layout
):
return
jnp
.
mean
(
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
layout
))
return
jnp
.
mean
(
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
data_
layout
))
value_n_grad_primitive_func
=
value_and_grad
(
primitive_func
,
(
0
,
1
))
value_n_grad_primitive_func
=
value_and_grad
(
primitive_func
,
(
0
,
1
))
...
@@ -722,19 +750,19 @@ class TestDense:
...
@@ -722,19 +750,19 @@ class TestDense:
primitive_out
,
(
primitive_x_grad
,
primitive_w_grad
)
=
value_n_grad_primitive_func
(
primitive_out
,
(
primitive_x_grad
,
primitive_w_grad
)
=
value_n_grad_primitive_func
(
x
,
w
,
contracting_dims
x
,
w
,
contracting_dims
)
)
ref_out
,
(
ref_x_grad
,
ref_w_grad
)
=
value_n_grad_ref_func
(
x
,
w
,
layout
)
ref_out
,
(
ref_x_grad
,
ref_w_grad
)
=
value_n_grad_ref_func
(
x
,
w
,
data_
layout
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_x_grad
,
ref_x_grad
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_x_grad
,
ref_x_grad
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_w_grad
,
ref_w_grad
,
dtype
=
jnp
.
bfloat16
)
assert_allclose
(
primitive_w_grad
,
ref_w_grad
,
dtype
=
jnp
.
bfloat16
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
512
,
128
,
256
)])
@
pytest_parametrize_wrapper
(
"m,n,k"
,
[(
64
,
32
,
64
)])
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest_parametrize_wrapper
(
"scaling_mode"
,
supported_scaling_modes
)
def
test_dense_grad_fp8
(
self
,
m
,
n
,
k
,
q_dtype
,
scaling_mode
):
def
test_dense_grad_fp8
(
self
,
m
,
n
,
k
,
q_dtype
,
scaling_mode
):
layout
=
"NN"
data_
layout
=
"NN"
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
layout
)
x
,
w
,
contracting_dims
=
self
.
_generate_gemm_input
(
m
,
n
,
k
,
data_
layout
)
key
=
jax
.
random
.
PRNGKey
(
1
)
key
=
jax
.
random
.
PRNGKey
(
1
)
bias
=
jax
.
random
.
uniform
(
key
,
n
,
dtype
=
jnp
.
bfloat16
)
bias
=
jax
.
random
.
uniform
(
key
,
n
,
dtype
=
jnp
.
bfloat16
)
...
@@ -745,9 +773,9 @@ class TestDense:
...
@@ -745,9 +773,9 @@ class TestDense:
)
)
return
jnp
.
mean
(
primitive_out
)
return
jnp
.
mean
(
primitive_out
)
def
ref_func
(
x
,
w
,
bias
,
layout
):
def
ref_func
(
x
,
w
,
bias
,
data_
layout
):
return
jnp
.
mean
(
return
jnp
.
mean
(
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
layout
)
+
jnp
.
expand_dims
(
bias
,
axis
=
0
)
self
.
_ref_gemm_with_jnp_dot
(
x
,
w
,
data_
layout
)
+
jnp
.
expand_dims
(
bias
,
axis
=
0
)
)
)
value_n_grad_primitive_func
=
value_and_grad
(
primitive_func
,
(
0
,
1
,
2
))
value_n_grad_primitive_func
=
value_and_grad
(
primitive_func
,
(
0
,
1
,
2
))
...
@@ -757,13 +785,15 @@ class TestDense:
...
@@ -757,13 +785,15 @@ class TestDense:
scaling_mode
=
scaling_mode
,
fwd_dtype
=
q_dtype
,
bwd_dtype
=
q_dtype
,
is_2x2x
=
True
scaling_mode
=
scaling_mode
,
fwd_dtype
=
q_dtype
,
bwd_dtype
=
q_dtype
,
is_2x2x
=
True
)
)
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
else
1
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
else
1
for
_
in
range
(
n_iterations
):
for
_
in
range
(
n_iterations
):
primitive_out
,
(
primitive_x_grad
,
primitive_w_grad
,
primitive_bias_grad
)
=
(
primitive_out
,
(
primitive_x_grad
,
primitive_w_grad
,
primitive_bias_grad
)
=
(
value_n_grad_primitive_func
(
x
,
w
,
bias
,
contracting_dims
,
quantizer_set
)
value_n_grad_primitive_func
(
x
,
w
,
bias
,
contracting_dims
,
quantizer_set
)
)
)
ref_out
,
(
ref_x_grad
,
ref_w_grad
,
ref_bias_grad
)
=
value_n_grad_ref_func
(
x
,
w
,
bias
,
layout
)
ref_out
,
(
ref_x_grad
,
ref_w_grad
,
ref_bias_grad
)
=
value_n_grad_ref_func
(
x
,
w
,
bias
,
data_layout
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
q_dtype
)
assert_allclose
(
primitive_out
,
ref_out
,
dtype
=
q_dtype
)
assert_allclose
(
primitive_x_grad
,
ref_x_grad
,
dtype
=
q_dtype
)
assert_allclose
(
primitive_x_grad
,
ref_x_grad
,
dtype
=
q_dtype
)
...
@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
...
@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class
TestFusedDense
:
class
TestFusedDense
:
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
[(
512
,
128
,
128
)])
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
[(
64
,
32
,
64
)])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest
.
mark
.
parametrize
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest
.
mark
.
parametrize
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest
.
mark
.
parametrize
(
"norm_type"
,
[
"layernorm"
,
"rmsnorm"
])
@
pytest
.
mark
.
parametrize
(
"norm_type"
,
[
"layernorm"
,
"rmsnorm"
])
...
@@ -800,7 +830,7 @@ class TestFusedDense:
...
@@ -800,7 +830,7 @@ class TestFusedDense:
Test layernorm_dense VJP Rule
Test layernorm_dense VJP Rule
"""
"""
# No Norm FWD E5M2 in TE backend
# No Norm FWD E5M2 in TE backend
if
q_dtype
==
jnp
.
float8_e5m2
and
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
q_dtype
==
jnp
.
float8_e5m2
and
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
pytest
.
skip
(
"E5M2 is not supported in normalization with TE Backend!"
)
pytest
.
skip
(
"E5M2 is not supported in normalization with TE Backend!"
)
# zero_centered_gamma is already tested in TestNorm
# zero_centered_gamma is already tested in TestNorm
...
@@ -856,7 +886,7 @@ class TestFusedDense:
...
@@ -856,7 +886,7 @@ class TestFusedDense:
x
,
w
,
gamma
,
beta
x
,
w
,
gamma
,
beta
)
)
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
else
1
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
else
1
for
_
in
range
(
n_iterations
):
for
_
in
range
(
n_iterations
):
prim_out
,
(
prim_out
,
(
prim_x_grad
,
prim_x_grad
,
...
@@ -873,7 +903,7 @@ class TestFusedDense:
...
@@ -873,7 +903,7 @@ class TestFusedDense:
assert_allclose
(
prim_beta_grad
,
ref_beta_grad
,
dtype
=
q_dtype
)
assert_allclose
(
prim_beta_grad
,
ref_beta_grad
,
dtype
=
q_dtype
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
[(
512
,
128
,
256
)])
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
[(
64
,
32
,
64
)])
@
pytest
.
mark
.
parametrize
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest
.
mark
.
parametrize
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest
.
mark
.
parametrize
(
"q_dtype"
,
[
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
])
@
pytest
.
mark
.
parametrize
(
"scaling_mode"
,
supported_scaling_modes
)
@
pytest
.
mark
.
parametrize
(
"scaling_mode"
,
supported_scaling_modes
)
...
@@ -886,7 +916,7 @@ class TestFusedDense:
...
@@ -886,7 +916,7 @@ class TestFusedDense:
Test layernorm_mlp VJP Rule
Test layernorm_mlp VJP Rule
"""
"""
# No Norm FWD E5M2 in TE backend
# No Norm FWD E5M2 in TE backend
if
q_dtype
==
jnp
.
float8_e5m2
and
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
q_dtype
==
jnp
.
float8_e5m2
and
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
pytest
.
skip
(
"E5M2 is not supported in normalization with TE Backend!"
)
pytest
.
skip
(
"E5M2 is not supported in normalization with TE Backend!"
)
# zero_centered_gamma is already tested in TestNorm
# zero_centered_gamma is already tested in TestNorm
...
@@ -898,13 +928,13 @@ class TestFusedDense:
...
@@ -898,13 +928,13 @@ class TestFusedDense:
x
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
m
,
k
),
jnp
.
bfloat16
)
x
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
m
,
k
),
jnp
.
bfloat16
)
kernel_1
=
jax
.
random
.
normal
(
kernel_1
=
jax
.
random
.
normal
(
subkeys
[
1
],
(
k
,
len
(
activation_type
)
*
n
),
jnp
.
bfloat16
subkeys
[
1
],
(
k
,
len
(
activation_type
)
,
n
),
jnp
.
bfloat16
)
/
jnp
.
sqrt
(
k
)
)
/
jnp
.
sqrt
(
k
)
kernel_2
=
jax
.
random
.
normal
(
subkeys
[
2
],
(
n
,
k
),
jnp
.
bfloat16
)
/
jnp
.
sqrt
(
n
)
kernel_2
=
jax
.
random
.
normal
(
subkeys
[
2
],
(
n
,
k
),
jnp
.
bfloat16
)
/
jnp
.
sqrt
(
n
)
gamma
=
jax
.
random
.
normal
(
subkeys
[
5
],
(
k
,),
jnp
.
bfloat16
)
gamma
=
jax
.
random
.
normal
(
subkeys
[
5
],
(
k
,),
jnp
.
bfloat16
)
beta
=
None
# was tested in TestNorm
beta
=
None
# was tested in TestNorm
if
use_bias
:
if
use_bias
:
bias_1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
*
n
),
jnp
.
bfloat16
)
bias_1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
,
n
),
jnp
.
bfloat16
)
bias_2
=
jax
.
random
.
normal
(
subkeys
[
4
],
(
k
,),
jnp
.
bfloat16
)
bias_2
=
jax
.
random
.
normal
(
subkeys
[
4
],
(
k
,),
jnp
.
bfloat16
)
else
:
else
:
bias_1
=
None
bias_1
=
None
...
@@ -963,7 +993,7 @@ class TestFusedDense:
...
@@ -963,7 +993,7 @@ class TestFusedDense:
value_n_grad_prim_func
=
value_and_grad
(
prim_func
,
range
(
6
))
value_n_grad_prim_func
=
value_and_grad
(
prim_func
,
range
(
6
))
value_n_grad_ref_func
=
value_and_grad
(
ref_func
,
range
(
6
))
value_n_grad_ref_func
=
value_and_grad
(
ref_func
,
range
(
6
))
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
else
1
n_iterations
=
3
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
else
1
for
_
in
range
(
n_iterations
):
for
_
in
range
(
n_iterations
):
prim_out
,
(
prim_out
,
(
prim_x_grad
,
prim_x_grad
,
...
@@ -1039,19 +1069,19 @@ class TestGroupedDense:
...
@@ -1039,19 +1069,19 @@ class TestGroupedDense:
subkeys
=
jax
.
random
.
split
(
key
,
len
(
shape_list
)
*
2
)
subkeys
=
jax
.
random
.
split
(
key
,
len
(
shape_list
)
*
2
)
lhs_list
,
rhs_list
,
contracting_dims_list
=
[],
[],
[]
lhs_list
,
rhs_list
,
contracting_dims_list
=
[],
[],
[]
for
i
,
((
m
,
n
,
k
),
layout
)
in
enumerate
(
zip
(
shape_list
,
layout_list
)):
for
i
,
((
m
,
n
,
k
),
data_
layout
)
in
enumerate
(
zip
(
shape_list
,
layout_list
)):
lhs
=
jax
.
random
.
uniform
(
lhs
=
jax
.
random
.
uniform
(
subkeys
[
2
*
i
],
subkeys
[
2
*
i
],
(
m
if
layout
[
0
]
==
"N"
else
k
,
k
if
layout
[
0
]
==
"N"
else
m
),
(
m
if
data_
layout
[
0
]
==
"N"
else
k
,
k
if
data_
layout
[
0
]
==
"N"
else
m
),
dtype
=
dtype
,
dtype
=
dtype
,
)
)
rhs
=
jax
.
random
.
uniform
(
rhs
=
jax
.
random
.
uniform
(
subkeys
[
2
*
i
+
1
],
subkeys
[
2
*
i
+
1
],
(
k
if
layout
[
1
]
==
"N"
else
n
,
n
if
layout
[
1
]
==
"N"
else
k
),
(
k
if
data_
layout
[
1
]
==
"N"
else
n
,
n
if
data_
layout
[
1
]
==
"N"
else
k
),
dtype
=
dtype
,
dtype
=
dtype
,
)
)
lhs_contracting_dim
=
(
1
,)
if
layout
[
0
]
==
"N"
else
(
0
,)
lhs_contracting_dim
=
(
1
,)
if
data_
layout
[
0
]
==
"N"
else
(
0
,)
rhs_contracting_dim
=
(
0
,)
if
layout
[
1
]
==
"N"
else
(
1
,)
rhs_contracting_dim
=
(
0
,)
if
data_
layout
[
1
]
==
"N"
else
(
1
,)
contracting_dims
=
(
lhs_contracting_dim
,
rhs_contracting_dim
)
contracting_dims
=
(
lhs_contracting_dim
,
rhs_contracting_dim
)
lhs_list
.
append
(
lhs
)
lhs_list
.
append
(
lhs
)
...
...
tests/jax/test_distributed_fused_attn.py
View file @
ab3e5a92
...
@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
...
@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
# for loss and dbias
# for loss and dbias
return
generate_collectives_count
(
allreduce
=
allreduce_total_bytes
,
allgather
=
0
,
other
=
0
)
return
generate_collectives_count
(
allreduce
=
allreduce_total_bytes
,
allgather
=
0
,
other
=
0
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
def
impl_test_self_attn
(
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[
pytest
.
param
((
32
,
512
,
12
,
64
),
id
=
"32-512-12-64"
),
pytest
.
param
((
32
,
1024
,
16
,
128
),
id
=
"32-1024-16-128"
),
],
)
@
pytest
.
mark
.
parametrize
(
"attn_bias_type, bias_shape"
,
[
pytest
.
param
(
AttnBiasType
.
NO_BIAS
,
None
,
id
=
"NO_BIAS"
),
pytest
.
param
(
AttnBiasType
.
PRE_SCALE_BIAS
,
BiasShape
.
_1HSS
,
id
=
"PRE_SCALE_BIAS-1HSS"
),
pytest
.
param
(
AttnBiasType
.
POST_SCALE_BIAS
,
BiasShape
.
_1HSS
,
id
=
"POST_SCALE_BIAS-1HSS"
),
],
)
@
pytest
.
mark
.
parametrize
(
"attn_mask_type"
,
[
pytest
.
param
(
AttnMaskType
.
PADDING_MASK
,
id
=
"PADDING_MASK"
),
pytest
.
param
(
AttnMaskType
.
CAUSAL_MASK
,
id
=
"CAUSAL_MASK"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_self_attn
(
self
,
self
,
device_count
,
device_count
,
mesh_shape
,
mesh_shape
,
...
@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
...
@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
bias_shape
,
bias_shape
,
attn_mask_type
,
attn_mask_type
,
dtype
,
dtype
,
use_shardy
,
):
):
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
dropout_prob
=
0.0
dropout_prob
=
0.0
is_training
=
True
is_training
=
True
...
@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
...
@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
)
)
runner
.
test_backward
()
runner
.
test_backward
()
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[
pytest
.
param
((
32
,
512
,
12
,
64
),
id
=
"32-512-12-64"
),
pytest
.
param
((
32
,
1024
,
16
,
128
),
id
=
"32-1024-16-128"
),
],
)
@
pytest
.
mark
.
parametrize
(
"attn_bias_type, bias_shape"
,
[
pytest
.
param
(
AttnBiasType
.
NO_BIAS
,
None
,
id
=
"NO_BIAS"
),
pytest
.
param
(
AttnBiasType
.
PRE_SCALE_BIAS
,
BiasShape
.
_1HSS
,
id
=
"PRE_SCALE_BIAS-1HSS"
),
pytest
.
param
(
AttnBiasType
.
POST_SCALE_BIAS
,
BiasShape
.
_1HSS
,
id
=
"POST_SCALE_BIAS-1HSS"
),
],
)
@
pytest
.
mark
.
parametrize
(
"attn_mask_type"
,
[
pytest
.
param
(
AttnMaskType
.
PADDING_MASK
,
id
=
"PADDING_MASK"
),
pytest
.
param
(
AttnMaskType
.
CAUSAL_MASK
,
id
=
"CAUSAL_MASK"
),
],
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_self_attn
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
attn_bias_type
,
bias_shape
,
attn_mask_type
,
dtype
,
):
self
.
impl_test_self_attn
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
attn_bias_type
,
bias_shape
,
attn_mask_type
,
dtype
,
use_shardy
=
False
,
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"attn_bias_type, bias_shape"
,
[
pytest
.
param
(
AttnBiasType
.
NO_BIAS
,
None
,
id
=
"NO_BIAS"
),
pytest
.
param
(
AttnBiasType
.
PRE_SCALE_BIAS
,
BiasShape
.
_1HSS
,
id
=
"PRE_SCALE_BIAS-1HSS"
),
],
)
def
test_self_attn_shardy
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
attn_bias_type
,
bias_shape
):
data_shape
=
(
32
,
512
,
12
,
64
)
self
.
impl_test_self_attn
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
attn_bias_type
,
bias_shape
,
AttnMaskType
.
PADDING_MASK
,
jnp
.
bfloat16
,
use_shardy
=
True
,
)
class
TestDistributedCrossAttn
:
class
TestDistributedCrossAttn
:
...
@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
...
@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
runner
.
test_backward
()
runner
.
test_backward
()
@
pytest
.
mark
.
parametrize
(
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS
=
[
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_context_parallel_configs
()
)
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
pytest
.
param
([
2
,
128
,
8
,
128
],
id
=
"2-128xCP-8-128"
),
pytest
.
param
([
4
,
256
,
16
,
64
],
id
=
"4-256xCP-16-64"
),
],
)
@
pytest
.
mark
.
parametrize
(
"kv_groups"
,
[
1
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
pytest
.
param
(
jnp
.
bfloat16
,
id
=
"BF16"
)])
@
pytest
.
mark
.
parametrize
(
"qkv_layout, attn_mask_type"
,
[
pytest
.
param
(
QKVLayout
.
BSHD_BS2HD
,
AttnMaskType
.
CAUSAL_MASK
,
id
=
"BSHD_KVPACKED-CAUSAL"
),
pytest
.
param
(
QKVLayout
.
BSHD_BS2HD
,
AttnMaskType
.
CAUSAL_MASK
,
id
=
"BSHD_KVPACKED-CAUSAL"
),
pytest
.
param
(
QKVLayout
.
BSHD_BSHD_BSHD
,
AttnMaskType
.
CAUSAL_MASK
,
id
=
"BSHD_SEPARATE-CAUSAL"
),
pytest
.
param
(
QKVLayout
.
BSHD_BSHD_BSHD
,
AttnMaskType
.
CAUSAL_MASK
,
id
=
"BSHD_SEPARATE-CAUSAL"
),
pytest
.
param
(
QKVLayout
.
BSHD_BS2HD
,
AttnMaskType
.
NO_MASK
,
id
=
"HD_KVPACKED-NO_MASK"
),
pytest
.
param
(
QKVLayout
.
BSHD_BS2HD
,
AttnMaskType
.
NO_MASK
,
id
=
"HD_KVPACKED-NO_MASK"
),
pytest
.
param
(
QKVLayout
.
BSHD_BSHD_BSHD
,
AttnMaskType
.
NO_MASK
,
id
=
"BSHD_SEPARATE-NO_MASK"
),
pytest
.
param
(
QKVLayout
.
BSHD_BSHD_BSHD
,
AttnMaskType
.
NO_MASK
,
id
=
"BSHD_SEPARATE-NO_MASK"
),
pytest
.
param
(
pytest
.
param
(
QKVLayout
.
THD_THD_THD
,
QKVLayout
.
THD_THD_THD
,
AttnMaskType
.
PADDING_CAUSAL_MASK
,
id
=
"THD_SEPARATE-PADDING_CAUSAL"
AttnMaskType
.
PADDING_CAUSAL_MASK
,
id
=
"THD_SEPARATE-PADDING_CAUSAL"
,
),
),
],
]
)
@
pytest
.
mark
.
parametrize
(
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES
=
[
"load_balanced"
,
# Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
[
pytest
.
param
(
True
,
id
=
"BALANCED"
),
pytest
.
param
(
False
,
id
=
"UNBALANCED"
)],
pytest
.
param
([
2
,
128
,
8
,
128
],
id
=
"2-128xCP-8-128"
),
)
pytest
.
param
([
4
,
256
,
16
,
64
],
id
=
"4-256xCP-16-64"
),
]
class
TestDistributedContextParallelSelfAttn
:
class
TestDistributedContextParallelSelfAttn
:
def
impl_test_context_parallel_attn
(
def
impl_test_context_parallel_attn
(
...
@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout
,
qkv_layout
,
load_balanced
,
load_balanced
,
cp_strategy
,
cp_strategy
,
use_shardy
,
use_scan_ring
=
False
,
):
):
if
qkv_layout
.
is_thd
():
if
cp_strategy
==
CPStrategy
.
ALL_GATHER
:
pytest
.
skip
(
"THD doesn't support all gather context parallelism."
)
if
not
load_balanced
and
cp_strategy
==
CPStrategy
.
RING
:
pytest
.
skip
(
"THD + ring doesn't support unbalanced context parallelism."
)
assert
not
use_scan_ring
or
cp_strategy
==
CPStrategy
.
RING
if
use_scan_ring
:
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
=
"1"
else
:
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
=
"0"
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
attn_bias_type
=
AttnBiasType
.
NO_BIAS
attn_bias_type
=
AttnBiasType
.
NO_BIAS
bias_shape
=
None
bias_shape
=
None
dropout_prob
=
0.0
dropout_prob
=
0.0
...
@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
pytest
.
skip
(
f
"Skipping
{
kv_groups
=
}
not multiple of
{
data_shape
=
}
or
{
tp_size
=
}
"
)
pytest
.
skip
(
f
"Skipping
{
kv_groups
=
}
not multiple of
{
data_shape
=
}
or
{
tp_size
=
}
"
)
runner
.
test_backward
()
runner
.
test_backward
()
del
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_context_parallel_configs
()
)
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES
[:
1
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
pytest
.
param
(
jnp
.
bfloat16
,
id
=
"BF16"
)])
@
pytest
.
mark
.
parametrize
(
"qkv_layout, attn_mask_type"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS
,
)
def
test_context_parallel_allgather_attn_shardy
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
attn_mask_type
,
dtype
,
qkv_layout
,
):
kv_groups
=
8
self
.
impl_test_context_parallel_attn
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
kv_groups
,
attn_mask_type
,
dtype
,
qkv_layout
,
load_balanced
=
True
,
cp_strategy
=
CPStrategy
.
ALL_GATHER
,
use_shardy
=
True
,
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_context_parallel_configs
()
)
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES
)
@
pytest
.
mark
.
parametrize
(
"kv_groups"
,
[
1
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
pytest
.
param
(
jnp
.
bfloat16
,
id
=
"BF16"
)])
@
pytest
.
mark
.
parametrize
(
"qkv_layout, attn_mask_type"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS
,
)
@
pytest
.
mark
.
parametrize
(
"load_balanced"
,
[
pytest
.
param
(
True
,
id
=
"BALANCED"
),
pytest
.
param
(
False
,
id
=
"UNBALANCED"
)],
)
def
test_context_parallel_allgather_attn
(
def
test_context_parallel_allgather_attn
(
self
,
self
,
device_count
,
device_count
,
...
@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout
,
qkv_layout
,
load_balanced
,
load_balanced
,
):
):
if
qkv_layout
.
is_thd
():
self
.
impl_test_context_parallel_attn
(
pytest
.
skip
(
"THD doesn't support all gather context parallelism."
)
return
self
.
impl_test_context_parallel_attn
(
device_count
,
device_count
,
mesh_shape
,
mesh_shape
,
mesh_axes
,
mesh_axes
,
...
@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout
,
qkv_layout
,
load_balanced
,
load_balanced
,
CPStrategy
.
ALL_GATHER
,
CPStrategy
.
ALL_GATHER
,
use_shardy
=
False
,
)
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_context_parallel_configs
()
)
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES
)
@
pytest
.
mark
.
parametrize
(
"kv_groups"
,
[
1
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
pytest
.
param
(
jnp
.
bfloat16
,
id
=
"BF16"
)])
@
pytest
.
mark
.
parametrize
(
"qkv_layout, attn_mask_type"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS
,
)
@
pytest
.
mark
.
parametrize
(
"load_balanced"
,
[
pytest
.
param
(
True
,
id
=
"BALANCED"
),
pytest
.
param
(
False
,
id
=
"UNBALANCED"
)],
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"use_scan"
,
"use_scan"
,
[
pytest
.
param
(
False
,
id
=
"NO_SCAN"
),
pytest
.
param
(
True
,
id
=
"USE_SCAN"
)],
[
pytest
.
param
(
False
,
id
=
"NO_SCAN"
),
pytest
.
param
(
True
,
id
=
"USE_SCAN"
)],
...
@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
load_balanced
,
load_balanced
,
use_scan
,
use_scan
,
):
):
if
use_scan
:
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
=
"1"
else
:
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
=
"0"
if
qkv_layout
.
is_thd
()
and
not
load_balanced
:
pytest
.
skip
(
"THD + ring doesn't support unbalanced context parallelism."
)
self
.
impl_test_context_parallel_attn
(
self
.
impl_test_context_parallel_attn
(
device_count
,
device_count
,
mesh_shape
,
mesh_shape
,
...
@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
...
@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
qkv_layout
,
qkv_layout
,
load_balanced
,
load_balanced
,
CPStrategy
.
RING
,
CPStrategy
.
RING
,
use_shardy
=
False
,
use_scan_ring
=
use_scan
,
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_context_parallel_configs
()
)
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES
[:
1
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
pytest
.
param
(
jnp
.
bfloat16
,
id
=
"BF16"
)])
@
pytest
.
mark
.
parametrize
(
"qkv_layout, attn_mask_type"
,
DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS
,
)
def
test_context_parallel_ring_attn_shardy
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
attn_mask_type
,
dtype
,
qkv_layout
,
):
kv_groups
=
8
self
.
impl_test_context_parallel_attn
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
kv_groups
,
attn_mask_type
,
dtype
,
qkv_layout
,
load_balanced
=
True
,
cp_strategy
=
CPStrategy
.
RING
,
use_shardy
=
False
,
use_scan_ring
=
True
,
)
)
del
os
.
environ
[
"NVTE_FUSED_RING_ATTENTION_USE_SCAN"
]
return
class
TestReorderCausalLoadBalancing
:
class
TestReorderCausalLoadBalancing
:
...
...
tests/jax/test_distributed_layernorm.py
View file @
ab3e5a92
...
@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
...
@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
}
}
is_fp8_supported
,
reason
=
is_fp8_available
()
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
MXFP8_1D_SCALING
)
SUPPORTED_RECIPES
=
[]
SUPPORTED_RECIPES
=
[]
if
is_fp8_supported
:
if
is_fp8_supported
:
...
@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
...
@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
@
pytest_parametrize_wrapper
(
"zero_centered_gamma"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"zero_centered_gamma"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_layernorm
(
def
test_layernorm
(
self
,
self
,
device_count
,
device_count
,
...
@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
...
@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
zero_centered_gamma
,
zero_centered_gamma
,
shard_weights
,
shard_weights
,
fp8_recipe
,
fp8_recipe
,
use_shardy
,
):
):
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
epsilon
=
1e-6
epsilon
=
1e-6
ln_type
=
"layernorm"
ln_type
=
"layernorm"
q_dtype
=
jnp
.
float8_e4m3fn
q_dtype
=
jnp
.
float8_e4m3fn
...
@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
...
@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"shard_weights"
,
[
False
,
True
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_rmsnorm
(
def
test_rmsnorm
(
self
,
self
,
device_count
,
device_count
,
...
@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
...
@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
dtype
,
dtype
,
shard_weights
,
shard_weights
,
fp8_recipe
,
fp8_recipe
,
use_shardy
,
):
):
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
epsilon
=
1e-6
epsilon
=
1e-6
ln_type
=
"rmsnorm"
ln_type
=
"rmsnorm"
q_dtype
=
jnp
.
float8_e4m3fn
q_dtype
=
jnp
.
float8_e4m3fn
...
...
tests/jax/test_distributed_layernorm_mlp.py
View file @
ab3e5a92
...
@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
...
@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported
,
reason
=
is_fp8_available
()
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
MXFP8_1D_SCALING
)
SUPPORTED_RECIPES
=
[]
SUPPORTED_RECIPES
=
[]
if
is_fp8_supported
:
if
is_fp8_supported
:
...
@@ -45,11 +45,17 @@ if is_mxfp8_supported:
...
@@ -45,11 +45,17 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
SUPPORTED_RECIPES
.
append
(
pytest
.
param
(
recipe
.
MXFP8BlockScaling
(),
id
=
"MXFP8BlockScaling"
))
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float16
]
DTYPES
=
[
jnp
.
bfloat16
,
jnp
.
float16
]
INPUT_SHAPE
=
[[
2
,
64
,
64
]]
# [batch, seqlen, hidden_in]
INPUT_SHAPE
=
[[
4
,
64
,
128
]]
# [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_TP_AXES
,
HIDDEN_AXES
)
LAYERNORM_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_TP_AXES
,
HIDDEN_AXES
)
DOT_1_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_AXES
)
DOT_1_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_AXES
)
DOT_2_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_TP_AXES
)
DOT_2_INPUT_AXES
=
(
BATCH_AXES
,
SEQLEN_AXES
,
HIDDEN_TP_AXES
)
KERNEL_1_AXES
=
(
W_FSDP_AXES
,
W_JOINED_AXES
,
W_TP_AXES
)
KERNEL_2_AXES
=
(
W_TP_AXES
,
W_FSDP_AXES
)
LN_SCALE_AXES
=
(
W_NO_SHARD_AXES
,)
LN_BIAS_AXES
=
(
W_NO_SHARD_AXES
,)
BIAS_1_AXES
=
(
W_JOINED_AXES
,
W_TP_AXES
)
BIAS_2_AXES
=
(
W_NO_SHARD_AXES
,)
INTERMEDIATE
=
64
INTERMEDIATE
=
64
...
@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
...
@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
configs
.
append
(
configs
.
append
(
[
2
,
(
1
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
[
2
,
(
1
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
)
)
if
is_devices_enough
(
4
):
if
is_devices_enough
(
4
):
configs
.
append
(
configs
.
append
(
[
4
,
(
2
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
[
4
,
(
2
,
2
),
(
"fsdp"
,
"tp"
),
MeshResource
(
fsdp_resource
=
"fsdp"
,
tp_resource
=
"tp"
)]
...
@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
...
@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
x
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
batch
,
seqlen
,
hidden_in
),
dtype
)
x
=
jax
.
random
.
normal
(
subkeys
[
0
],
(
batch
,
seqlen
,
hidden_in
),
dtype
)
gamma
=
jax
.
random
.
normal
(
subkeys
[
5
],
(
hidden_in
,),
dtype
=
dtype
)
gamma
=
jax
.
random
.
normal
(
subkeys
[
5
],
(
hidden_in
,),
dtype
=
dtype
)
k1
=
jax
.
random
.
normal
(
k1
=
jax
.
random
.
normal
(
subkeys
[
1
],
(
hidden_in
,
len
(
activation_type
)
*
INTERMEDIATE
),
dtype
subkeys
[
1
],
(
hidden_in
,
len
(
activation_type
)
,
INTERMEDIATE
),
dtype
)
/
jnp
.
sqrt
(
hidden_in
)
)
/
jnp
.
sqrt
(
hidden_in
)
k2
=
jax
.
random
.
normal
(
subkeys
[
2
],
(
INTERMEDIATE
,
hidden_out
),
dtype
)
/
jnp
.
sqrt
(
k2
=
jax
.
random
.
normal
(
subkeys
[
2
],
(
INTERMEDIATE
,
hidden_out
),
dtype
)
/
jnp
.
sqrt
(
INTERMEDIATE
INTERMEDIATE
)
)
if
use_bias
:
if
use_bias
:
b1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
*
INTERMEDIATE
),
dtype
)
b1
=
jax
.
random
.
normal
(
subkeys
[
3
],
(
len
(
activation_type
)
,
INTERMEDIATE
),
dtype
)
b2
=
jax
.
random
.
normal
(
subkeys
[
4
],
(
hidden_out
,),
dtype
)
b2
=
jax
.
random
.
normal
(
subkeys
[
4
],
(
hidden_out
,),
dtype
)
else
:
else
:
b1
=
None
b1
=
None
...
@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
...
@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
dot_1_input_axes
=
DOT_1_INPUT_AXES
dot_1_input_axes
=
DOT_1_INPUT_AXES
dot_2_input_axes
=
DOT_2_INPUT_AXES
dot_2_input_axes
=
DOT_2_INPUT_AXES
kernel_1_axes
=
KERNEL_1_AXES
kernel_2_axes
=
KERNEL_2_AXES
else
:
else
:
layernorm_input_axes
=
None
layernorm_input_axes
=
None
dot_1_input_axes
=
None
dot_1_input_axes
=
dot_2_input_axes
=
None
dot_2_input
_axes
=
None
kernel_1_axes
=
kernel_2
_axes
=
None
quantizer_sets
=
QuantizerFactory
.
create_set
(
n_quantizer_sets
=
2
)
quantizer_sets
=
QuantizerFactory
.
create_set
(
n_quantizer_sets
=
2
)
...
@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP:
...
@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP:
norm_input_axes
=
layernorm_input_axes
,
norm_input_axes
=
layernorm_input_axes
,
dot_1_input_axes
=
dot_1_input_axes
,
dot_1_input_axes
=
dot_1_input_axes
,
dot_2_input_axes
=
dot_2_input_axes
,
dot_2_input_axes
=
dot_2_input_axes
,
kernel_1_axes
=
kernel_1_axes
,
kernel_2_axes
=
kernel_2_axes
,
activation_type
=
activation_type
,
activation_type
=
activation_type
,
quantizer_sets
=
quantizer_sets
,
quantizer_sets
=
quantizer_sets
,
)
)
)
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
def
_test_layernorm_mlp_grad
(
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
,
use_shardy
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_layernorm_fp8_mlp_primitive
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
):
):
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
=
mesh_config
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
=
mesh_config
layernorm_type
=
"rmsnorm"
layernorm_type
=
"rmsnorm"
...
@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP:
...
@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP:
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
mesh
=
Mesh
(
devices
,
mesh_axes
)
with
mesh
,
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
with
mesh
,
fp8_autocast
(
enabled
=
True
,
fp8_recipe
=
fp8_recipe
,
mesh_resource
=
mesh_resource
):
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
"tp"
))
k1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"fsdp"
,
None
,
"tp"
))
k2_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
,
"fsdp"
))
k2_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
,
"fsdp"
))
k1_
=
jax
.
device_put
(
k1
,
k1_sharding
)
k1_
=
jax
.
device_put
(
k1
,
k1_sharding
)
k2_
=
jax
.
device_put
(
k2
,
k2_sharding
)
k2_
=
jax
.
device_put
(
k2
,
k2_sharding
)
if
use_bias
:
if
use_bias
:
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
"tp"
))
b1_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
,
"tp"
))
b1_
=
jax
.
device_put
(
b1
,
b1_sharding
)
b1_
=
jax
.
device_put
(
b1
,
b1_sharding
)
else
:
else
:
b1_sharding
=
b1_
=
None
b1_sharding
=
b1_
=
None
...
@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP:
...
@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP:
err_msg
=
f
"multi_grads[
{
i
}
] is not close"
,
err_msg
=
f
"multi_grads[
{
i
}
] is not close"
,
)
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_layernorm_mlp_grad
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
):
self
.
_test_layernorm_mlp_grad
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
,
use_shardy
=
False
,
)
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
def
test_layernorm_mlp_grad_shardy
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
):
# We don't test block scaling with Shardy because at the time of writing,
# it is not supported in JAX's scaled_matmul_stablehlo.
self
.
_test_layernorm_mlp_grad
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
=
recipe
.
DelayedScaling
(),
use_shardy
=
True
,
)
def
_test_layernorm_mlp
(
def
_test_layernorm_mlp
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
,
fp8_recipe
=
None
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
,
fp8_recipe
,
use_shardy
,
):
):
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
batch
,
seqlen
,
hidden_in
=
input_shape
batch
,
seqlen
,
hidden_in
=
input_shape
layernorm_type
=
"rmsnorm"
layernorm_type
=
"rmsnorm"
...
@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP:
...
@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP:
activations
=
activation_type
,
activations
=
activation_type
,
use_bias
=
use_bias
,
use_bias
=
use_bias
,
)
)
params_single
=
ln_mlp_single
.
init
(
init_rngs
,
x
)
params_single
=
ln_mlp_single
.
init
(
init_rngs
,
x
,
deterministic
=
True
)
mlp_out_single
,
ln_out_single
=
ln_mlp_single
.
apply
(
mlp_out_single
,
ln_out_single
=
ln_mlp_single
.
apply
(
params_single
,
x
,
deterministic
=
True
params_single
,
x
,
deterministic
=
True
)
)
...
@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP:
...
@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence
=
False
,
transpose_batch_sequence
=
False
,
intermediate_dim
=
INTERMEDIATE
,
intermediate_dim
=
INTERMEDIATE
,
activations
=
activation_type
,
activations
=
activation_type
,
scale_axes
=
(
W_NO_SHARD
_AXES
,
),
scale_axes
=
LN_SCALE
_AXES
,
ln_bias_axes
=
(
W_NO_SHARD
_AXES
,
),
ln_bias_axes
=
LN_BIAS
_AXES
,
kernel_axes_1
=
(
W_FSDP_AXES
,
W_JOINED_AXES
,
W_TP
_AXES
)
,
kernel_axes_1
=
KERNEL_1
_AXES
,
kernel_axes_2
=
(
W_TP_AXES
,
W_FSDP
_AXES
)
,
kernel_axes_2
=
KERNEL_2
_AXES
,
use_bias
=
use_bias
,
use_bias
=
use_bias
,
bias_axes_1
=
(
W_JOINED_AXES
,
W_TP
_AXES
)
,
bias_axes_1
=
BIAS_1
_AXES
,
bias_axes_2
=
(
W_NO_SHARD
_AXES
,
),
bias_axes_2
=
BIAS_2
_AXES
,
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
,
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
,
dot_1_input_axes
=
DOT_1_INPUT_AXES
,
dot_1_input_axes
=
DOT_1_INPUT_AXES
,
dot_2_input_axes
=
DOT_2_INPUT_AXES
,
dot_2_input_axes
=
DOT_2_INPUT_AXES
,
name
=
"mlp"
,
name
=
"mlp"
,
)
)
params_sharded
=
ln_mlp_sharded
.
init
(
init_rngs
,
x
)
params_sharded
=
ln_mlp_sharded
.
init
(
init_rngs
,
x
,
deterministic
=
True
)
mlp_out_sharded
,
ln_out_sharded
=
ln_mlp_sharded
.
apply
(
mlp_out_sharded
,
ln_out_sharded
=
ln_mlp_sharded
.
apply
(
params_sharded
,
x
,
deterministic
=
True
params_sharded
,
x
,
deterministic
=
True
)
)
...
@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP:
...
@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP:
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
def
test_layernorm_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
):
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_layernorm_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_shardy
):
self
.
_test_layernorm_mlp
(
self
.
_test_layernorm_mlp
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
False
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
False
,
fp8_recipe
=
None
,
use_shardy
=
use_shardy
,
)
)
# TODO: debug
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
# @pytest.mark.skipif(not is_fp8_supported, reason=reason)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
# @pytest_parametrize_wrapper(
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
# "activation_type", [("gelu",), ("gelu", "linear")]
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
# )
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
# @pytest_parametrize_wrapper("use_bias", [True, False])
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
def
test_layernorm_mlp_layer_fp8
(
# @pytest_parametrize_wrapper("dtype", DTYPES)
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
):
# def test_layernorm_fp8_mlp_layer(
self
.
_test_layernorm_mlp
(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
mesh_config
,
# ):
activation_type
,
# self._test_layernorm_mlp(
use_bias
,
# mesh_config, activation_type, use_bias, input_shape, dtype,
input_shape
,
# use_fp8=True, fp8_recipe=fp8_recipe
dtype
,
# )
use_fp8
=
True
,
fp8_recipe
=
fp8_recipe
,
use_shardy
=
False
,
)
tests/jax/test_distributed_softmax.py
View file @
ab3e5a92
...
@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
...
@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
all_reduce_loss_bytes
=
4
# 1 * FP32
all_reduce_loss_bytes
=
4
# 1 * FP32
return
generate_collectives_count
(
allreduce
=
all_reduce_loss_bytes
,
allgather
=
0
,
other
=
0
)
return
generate_collectives_count
(
allreduce
=
all_reduce_loss_bytes
,
allgather
=
0
,
other
=
0
)
def
generate_inputs
(
self
,
shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
):
def
generate_inputs
(
self
,
shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
):
batch
,
_
,
sqelen
,
_
=
shape
batch
,
_
,
sqelen
,
_
=
shape
x
=
random
.
normal
(
random
.
PRNGKey
(
1124
),
shape
,
dtype
=
dtype
)
x
=
random
.
normal
(
random
.
PRNGKey
(
1124
),
shape
,
dtype
=
dtype
)
if
softmax_type
==
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
if
softmax_type
==
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
:
mask
=
make_causal_mask
(
batch
,
sqelen
)
mask
=
make_causal_mask
(
batch
,
sqelen
)
else
:
else
:
mask
=
make_self_mask
(
batch
,
sqelen
)
mask
=
make_self_mask
(
1
if
broadcast_batch_mask
else
batch
,
sqelen
)
if
not
bad_sharding
:
if
not
bad_sharding
:
x_pspec
=
PartitionSpec
(
x_pspec
=
PartitionSpec
(
...
@@ -45,6 +47,10 @@ class TestDistributedSoftmax:
...
@@ -45,6 +47,10 @@ class TestDistributedSoftmax:
x_pspec
=
PartitionSpec
(
x_pspec
=
PartitionSpec
(
mesh_resource
.
dp_resource
,
None
,
None
,
mesh_resource
.
tp_resource
mesh_resource
.
dp_resource
,
None
,
None
,
mesh_resource
.
tp_resource
)
)
if
broadcast_batch_mask
:
mask_pspec
=
PartitionSpec
(
None
,
None
,
None
,
None
)
else
:
mask_pspec
=
PartitionSpec
(
mesh_resource
.
dp_resource
,
None
,
None
,
None
)
mask_pspec
=
PartitionSpec
(
mesh_resource
.
dp_resource
,
None
,
None
,
None
)
return
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
return
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
...
@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
...
@@ -67,16 +73,7 @@ class TestDistributedSoftmax:
output
=
jax
.
nn
.
softmax
(
x
*
scale_factor
)
output
=
jax
.
nn
.
softmax
(
x
*
scale_factor
)
return
jnp
.
mean
(
output
)
return
jnp
.
mean
(
output
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
def
impl_test_softmax
(
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[[
32
,
12
,
128
,
128
],
[
64
,
16
,
1024
,
1024
]])
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
[
SoftmaxType
.
SCALED
,
SoftmaxType
.
SCALED_MASKED
,
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
],
)
@
pytest
.
mark
.
parametrize
(
"scale_factor"
,
[
1.0
,
3.0
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"bad_sharding"
,
[
False
,
True
])
def
test_softmax
(
self
,
self
,
device_count
,
device_count
,
mesh_shape
,
mesh_shape
,
...
@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
...
@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
scale_factor
,
scale_factor
,
dtype
,
dtype
,
bad_sharding
,
bad_sharding
,
broadcast_batch_mask
,
use_shardy
,
):
):
if
broadcast_batch_mask
and
softmax_type
!=
SoftmaxType
.
SCALED_MASKED
:
pytest
.
skip
(
"Softmax type has no mask."
)
jax
.
config
.
update
(
"jax_use_shardy_partitioner"
,
use_shardy
)
target_func
=
partial
(
target_func
=
partial
(
self
.
target_func
,
scale_factor
=
scale_factor
,
softmax_type
=
softmax_type
self
.
target_func
,
scale_factor
=
scale_factor
,
softmax_type
=
softmax_type
)
)
ref_func
=
partial
(
self
.
ref_func
,
scale_factor
=
scale_factor
,
dtype
=
dtype
)
ref_func
=
partial
(
self
.
ref_func
,
scale_factor
=
scale_factor
,
dtype
=
dtype
)
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
=
self
.
generate_inputs
(
(
x
,
mask
),
(
x_pspec
,
mask_pspec
)
=
self
.
generate_inputs
(
data_shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
data_shape
,
mesh_resource
,
softmax_type
,
dtype
,
bad_sharding
,
broadcast_batch_mask
)
)
collective_count_ref
=
self
.
generate_collectives_count_ref
()
collective_count_ref
=
self
.
generate_collectives_count_ref
()
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
devices
=
np
.
asarray
(
jax
.
devices
()[:
device_count
]).
reshape
(
*
mesh_shape
)
...
@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
...
@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
assert
"Sharding the hidden dimension is not supported"
in
str
(
w
),
(
assert
"Sharding the hidden dimension is not supported"
in
str
(
w
),
(
"Softmax primitive did not raise the correct warning for "
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
"unsupported sharding in the hidden dimension."
f
"
{
str
(
w
)
}
"
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"data_shape"
,
[[
32
,
12
,
128
,
128
],
[
64
,
16
,
1024
,
1024
]])
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
[
SoftmaxType
.
SCALED
,
SoftmaxType
.
SCALED_MASKED
,
SoftmaxType
.
SCALED_UPPER_TRIANG_MASKED
],
)
@
pytest
.
mark
.
parametrize
(
"scale_factor"
,
[
1.0
,
3.0
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"bad_sharding"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"broadcast_batch_mask"
,
[
False
,
True
])
def
test_softmax
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
softmax_type
,
scale_factor
,
dtype
,
bad_sharding
,
broadcast_batch_mask
,
):
self
.
impl_test_softmax
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
,
softmax_type
,
scale_factor
,
dtype
,
bad_sharding
,
broadcast_batch_mask
,
use_shardy
=
False
,
)
@
pytest
.
mark
.
parametrize
(
"device_count,mesh_shape,mesh_axes,mesh_resource"
,
generate_configs
())
@
pytest
.
mark
.
parametrize
(
"softmax_type"
,
[
SoftmaxType
.
SCALED
,
SoftmaxType
.
SCALED_MASKED
])
@
pytest
.
mark
.
parametrize
(
"bad_sharding"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"broadcast_batch_mask"
,
[
False
,
True
])
def
test_softmax_shardy
(
self
,
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
softmax_type
,
bad_sharding
,
broadcast_batch_mask
,
):
self
.
impl_test_softmax
(
device_count
,
mesh_shape
,
mesh_axes
,
mesh_resource
,
data_shape
=
[
32
,
12
,
128
,
128
],
softmax_type
=
softmax_type
,
scale_factor
=
1.0
,
dtype
=
DTYPES
[
0
],
bad_sharding
=
bad_sharding
,
broadcast_batch_mask
=
broadcast_batch_mask
,
use_shardy
=
True
,
)
)
tests/jax/test_layer.py
View file @
ab3e5a92
...
@@ -39,7 +39,7 @@ def enable_fused_attn():
...
@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported
,
reason
=
is_fp8_available
()
is_fp8_supported
,
reason
=
is_fp8_available
()
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
)
is_mxfp8_supported
,
reason
=
is_fp8_available
(
ScalingMode
.
MXFP8_1D_SCALING
)
QUANTIZE_RECIPES
=
[]
QUANTIZE_RECIPES
=
[]
""" Find supported scaling modes"""
""" Find supported scaling modes"""
...
@@ -215,12 +215,53 @@ ATTRS = [
...
@@ -215,12 +215,53 @@ ATTRS = [
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
},
# attrs22
# attrs22
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"causal"
,
_KEY_OF_WINDOW_SIZE
:
None
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs23
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"causal"
,
_KEY_OF_FLOAT32_ATTENTION_LOGITS
:
True
,
},
# attrs24
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"no_mask"
,
},
# attrs25
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"no_mask"
,
_KEY_OF_WINDOW_SIZE
:
(
2
,
2
),
},
# attrs26
{
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_WINDOW_SIZE
:
(
2
,
2
),
_KEY_OF_WINDOW_SIZE
:
(
2
,
2
),
},
},
# attrs27
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_SELF_ATTN_MASK_TYPE
:
"padding"
,
_KEY_OF_WINDOW_SIZE
:
None
,
},
# attrs28
{
_KEY_OF_TRANSPOSE_BS
:
False
,
_KEY_OF_RELATIVE_EMBEDDING
:
False
,
_KEY_OF_WINDOW_SIZE
:
(
2
,
2
),
},
]
]
ATTRS
=
[{
**
BASE_ATTRS
,
**
attr
}
for
attr
in
ATTRS
]
ATTRS
=
[{
**
BASE_ATTRS
,
**
attr
}
for
attr
in
ATTRS
]
...
@@ -313,7 +354,7 @@ class BaseRunner:
...
@@ -313,7 +354,7 @@ class BaseRunner:
test_others
,
test_others
,
test_layer
,
test_layer
,
)
)
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
_
,
updated_quantize_meta
=
flax
.
core
.
pop
(
_
,
updated_quantize_meta
=
flax
.
core
.
pop
(
updated_state
[
0
],
QuantizeConfig
.
COLLECTION_NAME
updated_state
[
0
],
QuantizeConfig
.
COLLECTION_NAME
)
)
...
@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
...
@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
data_rng
=
jax
.
random
.
PRNGKey
(
2024
)
data_rng
=
jax
.
random
.
PRNGKey
(
2024
)
inputs
=
(
jax
.
random
.
normal
(
data_rng
,
data_shape
,
dtype
),)
inputs
=
(
jax
.
random
.
normal
(
data_rng
,
data_shape
,
dtype
),)
padded_mask
=
jnp
.
zeros
((
batch
,
1
,
seqlen
,
seqlen
),
dtype
=
jnp
.
uint8
)
mask_shape
=
(
batch
,
1
,
seqlen
,
seqlen
)
causal_mask
=
jnp
.
triu
(
jnp
.
ones
((
batch
,
1
,
seqlen
,
seqlen
),
dtype
=
jnp
.
uint8
),
k
=
1
)
padded_mask
=
jnp
.
zeros
(
mask_shape
,
dtype
=
jnp
.
uint8
)
causal_mask
=
jnp
.
triu
(
jnp
.
ones
(
mask_shape
,
dtype
=
jnp
.
uint8
),
k
=
1
)
if
self
.
attrs
[
_KEY_OF_SELF_ATTN_MASK_TYPE
]
in
[
"causal"
,
"padding_causal"
]:
if
self
.
attrs
[
_KEY_OF_SELF_ATTN_MASK_TYPE
]
in
[
"causal"
,
"padding_causal"
]:
mask
=
causal_mask
mask
=
causal_mask
else
:
else
:
mask
=
padded_mask
mask
=
padded_mask
ref_masks
=
(
1
-
mask
,)
ref_masks
=
(
1
-
mask
,)
test_masks
=
(
None
,
mask
)
# The second arg of Transformer is encoded tokens.
test_masks
=
(
None
,
mask
)
# The second arg of Transformer is encoded tokens.
...
...
tests/jax/test_softmax.py
View file @
ab3e5a92
...
@@ -18,6 +18,7 @@ from utils import assert_allclose
...
@@ -18,6 +18,7 @@ from utils import assert_allclose
from
transformer_engine.jax.cpp_extensions
import
is_softmax_kernel_available
from
transformer_engine.jax.cpp_extensions
import
is_softmax_kernel_available
from
transformer_engine.jax.softmax
import
SoftmaxType
,
softmax
from
transformer_engine.jax.softmax
import
SoftmaxType
,
softmax
from
transformer_engine.jax.flax.module
import
Softmax
def
catch_unsupported
(
method
):
def
catch_unsupported
(
method
):
...
@@ -94,7 +95,6 @@ class SoftmaxRunner:
...
@@ -94,7 +95,6 @@ class SoftmaxRunner:
case
_
:
case
_
:
raise
ValueError
(
f
"Unknown
{
self
.
softmax_type
=
}
"
)
raise
ValueError
(
f
"Unknown
{
self
.
softmax_type
=
}
"
)
@
catch_unsupported
def
test_forward
(
self
):
def
test_forward
(
self
):
"""
"""
Test transformer_engine.jax.softmax.softmax fwd rule
Test transformer_engine.jax.softmax.softmax fwd rule
...
@@ -104,7 +104,6 @@ class SoftmaxRunner:
...
@@ -104,7 +104,6 @@ class SoftmaxRunner:
reference_out
=
__class__
.
reference_softmax
(
self
.
logits
,
self
.
mask
,
self
.
scale_factor
)
reference_out
=
__class__
.
reference_softmax
(
self
.
logits
,
self
.
mask
,
self
.
scale_factor
)
assert_allclose
(
primitive_out
,
reference_out
,
dtype
=
self
.
dtype
)
assert_allclose
(
primitive_out
,
reference_out
,
dtype
=
self
.
dtype
)
@
catch_unsupported
def
test_backward
(
self
):
def
test_backward
(
self
):
"""
"""
Test transformer_engine.jax.softmax.softmax bwd rule
Test transformer_engine.jax.softmax.softmax bwd rule
...
@@ -141,6 +140,50 @@ class SoftmaxRunner:
...
@@ -141,6 +140,50 @@ class SoftmaxRunner:
assert_allclose
(
primitive_grad_logits
,
reference_grad_logits
,
dtype
=
self
.
dtype
)
assert_allclose
(
primitive_grad_logits
,
reference_grad_logits
,
dtype
=
self
.
dtype
)
class SoftmaxPrimitivesRunner(SoftmaxRunner):
    """
    Jax Softmax Primitives runner

    Thin wrapper over SoftmaxRunner whose forward/backward tests are wrapped
    with @catch_unsupported so configurations the softmax kernel cannot
    handle are skipped instead of failing.
    """

    @catch_unsupported
    def test_forward(self):
        # Delegates to SoftmaxRunner.test_forward; the decorator intercepts
        # "unsupported" errors raised by the primitive.
        return super().test_forward()

    @catch_unsupported
    def test_backward(self):
        # Delegates to SoftmaxRunner.test_backward under the same guard.
        return super().test_backward()
class SoftmaxModuleRunner:
    """
    Jax Softmax Module runner

    Drives a flax-module-level softmax test: initializes a
    transformer_engine.jax.flax.module.Softmax with the configuration held by
    an inner SoftmaxRunner and compares its output against the runner's
    reference implementation.
    """

    # Inner runner providing logits/mask fixtures and the reference softmax.
    module_runner: SoftmaxRunner
    # Bias input for the module; currently always None in the tests.
    bias: None

    def __init__(self, module_runner, bias):
        self.module_runner = module_runner
        self.bias = bias

    def test_forward(self):
        """
        Test transformer_engine.jax.flax.module.Softmax fwd rule
        """
        runner = self.module_runner
        # Materialize runner.logits / runner.mask before building the module.
        runner._setup_inputs()
        rng = jax.random.PRNGKey(0)
        softmax_module = Softmax(
            scale_factor=runner.scale_factor,
            softmax_type=runner.softmax_type,
        )
        softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
        module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
        # Compare against the pure-JAX reference within dtype tolerance.
        reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
        assert_allclose(module_out, reference_out, dtype=runner.dtype)
# Run softmax primitives test
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"b, s_q, s_kv, h"
,
"b, s_q, s_kv, h"
,
[
[
...
@@ -165,7 +208,7 @@ class SoftmaxRunner:
...
@@ -165,7 +208,7 @@ class SoftmaxRunner:
pytest
.
param
(
jnp
.
float16
,
id
=
"FP16"
),
pytest
.
param
(
jnp
.
float16
,
id
=
"FP16"
),
],
],
)
)
class
TestSoftmax
:
class
TestSoftmax
Primitives
:
"""
"""
Test transformer_engine.jax.softmax.softmax
Test transformer_engine.jax.softmax.softmax
"""
"""
...
@@ -175,7 +218,7 @@ class TestSoftmax:
...
@@ -175,7 +218,7 @@ class TestSoftmax:
"""
"""
Test forward with parameterized configs
Test forward with parameterized configs
"""
"""
runner
=
SoftmaxRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
=
Softmax
Primitives
Runner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
.
test_forward
()
runner
.
test_forward
()
@
staticmethod
@
staticmethod
...
@@ -183,5 +226,48 @@ class TestSoftmax:
...
@@ -183,5 +226,48 @@ class TestSoftmax:
"""
"""
Test forward with parameterized configs
Test forward with parameterized configs
"""
"""
runner
=
SoftmaxRunner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
=
Softmax
Primitives
Runner
(
b
,
s_q
,
s_kv
,
h
,
scale_factor
,
softmax_type
,
dtype
)
runner
.
test_backward
()
runner
.
test_backward
()
# Run Softmax module test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
        # triggers backup framework implementation due to (s_q % 4) != 0
        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
    ],
)
@pytest.mark.parametrize("scale_factor", [0.125])
@pytest.mark.parametrize(
    "softmax_type",
    [
        pytest.param(SoftmaxType.SCALED, id="SCALED"),
        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
    ],
)
@pytest.mark.parametrize(
    "dtype",
    [
        pytest.param(jnp.bfloat16, id="BF16"),
        pytest.param(jnp.float16, id="FP16"),
    ],
)
class TestSoftmaxModule:
    """
    Test transformer_engine.jax.flax.module.Softmax
    """

    @staticmethod
    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
        """
        Test forward with parameterized configs
        """
        # Build the low-level runner for fixtures/reference, then wrap it in
        # the module-level runner which exercises the flax Softmax module.
        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        bias = None
        runner = SoftmaxModuleRunner(module_runner, bias)
        runner.test_forward()
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
View file @
ab3e5a92
...
@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
...
@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
)
)
import
transformer_engine.pytorch
as
te
import
transformer_engine.pytorch
as
te
from
transformer_engine.pytorch.tensor
import
QuantizedTensor
,
cast_master_weights_to_fp8
from
transformer_engine.pytorch.tensor
import
QuantizedTensor
,
cast_master_weights_to_fp8
from
transformer_engine.pytorch.tensor.float8_tensor
import
Float8Tensor
from
transformer_engine.pytorch.tensor.float8_tensor
import
(
Float8Tensor
,
Float8CurrentScalingQuantizer
,
)
from
transformer_engine.pytorch.tensor.utils
import
replace_raw_data
def
_get_raw_data
(
quantized_tensor
):
def
_get_raw_data
(
quantized_tensor
):
...
@@ -228,6 +232,273 @@ class MiniOptimizer:
...
@@ -228,6 +232,273 @@ class MiniOptimizer:
weight
.
data
.
copy_
(
master_weight
)
weight
.
data
.
copy_
(
master_weight
)
class MiniFSDP:
    """Minimal FSDP-style (fully-sharded) SGD optimizer used to validate
    ``cast_master_weights_to_fp8`` against a BF16 reference.

    All weights are flattened into one buffer padded to a multiple of the
    data-parallel world size, sharded across ranks, and updated from FP32
    master-weight shards; updated weights are then all-gathered back.
    """

    def __init__(self, weights, lr, dp_group):
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # Flatten the weights and pad to align with world size.
        if isinstance(weights[0], QuantizedTensor):
            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
        else:
            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into shards, one per rank.
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
        shard_size = self.flatten_weight.size(0) // world_size

        # Map original tensors to [start, end) ranges in the flattened buffer.
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build shard index mappings: for each weight, the slice owned by this
        # rank expressed in weight-local coordinates (weight_indices) and in
        # shard-local coordinates (shard_indices). (None, None) means this
        # rank owns no part of that weight.
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            # Exclude the zero padding at the tail of the buffer.
            adjusted_end = min(shard_end, original_length)
            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))
            # Re-point each weight's storage into the flattened buffer so the
            # all-gather in step() updates model parameters in place.
            if isinstance(weights[idx], QuantizedTensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model weight views and FP32 master-weight shards.
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)
                if isinstance(weight, QuantizedTensor):
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            # FP32 gradient accumulation buffer filled during backward.
            weight.main_grad = torch.zeros_like(weight, dtype=torch.float32, device="cuda")

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)
        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)
        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            # BUGFIX: allocate the padding on the same device as the flattened
            # tensor; torch.cat raises a device mismatch if the inputs live on
            # CUDA while the padding defaults to CPU.
            flatten_tensor = torch.cat(
                [
                    flatten_tensor,
                    torch.zeros(
                        padding_needed,
                        dtype=flatten_tensor.dtype,
                        device=flatten_tensor.device,
                    ),
                ]
            )
        return flatten_tensor, original_length

    def zero_grad(self):
        """Reset autograd grads and zero the FP32 accumulation buffers."""
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
            1. Gradient reduce-scatter: Synchronize gradients across all processes.
            2. Master weight update: Update high-precision master weights using local gradients.
            3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
            4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients.
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
        dist.reduce_scatter_tensor(
            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
        )

        # Step 2: Update the master weights.
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue
            # Extract the local gradient shard for this weight.
            grad = self.local_main_grad_shard[shard_start:shard_end]
            # Update the master weight using gradient descent.
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision.
        if isinstance(self.weights[0], QuantizedTensor):
            # Entries are already either a shard view or None; pass a shallow
            # copy through (the element-by-element loop here was redundant).
            local_weights = list(self.local_weights)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue
                # Copy updated master weights to local weights.
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes.
        dist.all_gather_into_tensor(
            self.flatten_weight, self.local_weight_shard, group=self.dp_group
        )
def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
    """Train an FP8 model and a BF16 reference model with MiniFSDP for
    NUM_STEPS steps and assert their losses match bitwise at every step.

    Args:
        quantization: quantization scheme name (or None for pure BF16).
        dp_group: data-parallel process group used for sharding.
    """
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    # One single-rank group per rank so amax reductions stay rank-local.
    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": False,
    }

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

    # CONSISTENCY FIX: loop over NUM_STEPS instead of a duplicated literal
    # 100, so the constant and the success message cannot drift apart.
    for _ in range(NUM_STEPS):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")
            for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)
        loss_fp8.backward()
        loss.backward()
        optimizer_fp8.step()
        optimizer.step()
        # Bitwise comparison: both paths must round identically.
        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

    print(
        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
        f" {quantization} quantization"
    )
def
_test_zero_1
(
dp_group
):
def
_test_zero_1
(
dp_group
):
"""Make sure the implementation of zero-1 optimizer is correct"""
"""Make sure the implementation of zero-1 optimizer is correct"""
rank
=
dist
.
get_rank
(
dp_group
)
rank
=
dist
.
get_rank
(
dp_group
)
...
@@ -389,6 +660,7 @@ def main(argv=None, namespace=None):
...
@@ -389,6 +660,7 @@ def main(argv=None, namespace=None):
dp_group
=
dist
.
new_group
(
backend
=
"nccl"
)
dp_group
=
dist
.
new_group
(
backend
=
"nccl"
)
_test_zero_1
(
dp_group
)
_test_zero_1
(
dp_group
)
_test_cast_master_weights_to_fp8
(
args
.
quantization
,
dp_group
)
_test_cast_master_weights_to_fp8
(
args
.
quantization
,
dp_group
)
_test_fsdp_cast_master_weights_to_fp8
(
args
.
quantization
,
dp_group
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
return
0
return
0
...
...
tests/pytorch/distributed/run_numerics.py
View file @
ab3e5a92
...
@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
...
@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
MXFP8BlockScaling
,
MXFP8BlockScaling
,
DelayedScaling
,
DelayedScaling
,
Float8CurrentScaling
,
Float8CurrentScaling
,
Float8BlockScaling
,
Format
,
Format
,
Recipe
,
Recipe
,
)
)
...
@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe:
...
@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe:
return
MXFP8BlockScaling
()
return
MXFP8BlockScaling
()
if
QUANTIZATION
==
"fp8_cs"
:
if
QUANTIZATION
==
"fp8_cs"
:
return
Float8CurrentScaling
()
return
Float8CurrentScaling
()
if
QUANTIZATION
==
"fp8_block_scaling"
:
return
Float8BlockScaling
()
return
te
.
fp8
.
get_default_fp8_recipe
()
return
te
.
fp8
.
get_default_fp8_recipe
()
...
@@ -86,7 +89,7 @@ def main(argv=None, namespace=None):
...
@@ -86,7 +89,7 @@ def main(argv=None, namespace=None):
# Quantization scheme
# Quantization scheme
QUANTIZATION
=
args
.
quantization
QUANTIZATION
=
args
.
quantization
if
QUANTIZATION
in
(
"fp8"
,
"mxfp8"
):
if
QUANTIZATION
in
(
"fp8"
,
"mxfp8"
,
"fp8_block_scaling"
):
global
SEQ_LEN
,
BATCH_SIZE
,
HIDDEN_SIZE
global
SEQ_LEN
,
BATCH_SIZE
,
HIDDEN_SIZE
SEQ_LEN
=
32
SEQ_LEN
=
32
BATCH_SIZE
=
32
BATCH_SIZE
=
32
...
@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed):
...
@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed):
LOSS_FN
(
output_distributed
,
target
).
backward
()
LOSS_FN
(
output_distributed
,
target
).
backward
()
def
_loss_backward_dw
(
model_single_node
,
model_distributed
):
model_single_node
.
backward_dw
()
model_distributed
.
backward_dw
()
def
_alloc_main_grad
(
model_single_node
,
model_distributed
):
def
_alloc_main_grad
(
model_single_node
,
model_distributed
):
for
model
in
[
model_single_node
,
model_distributed
]:
for
model
in
[
model_single_node
,
model_distributed
]:
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
...
@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
...
@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
# Compute loss and backpropagate
# Compute loss and backpropagate
_loss_backward
(
output_single_node
,
output_distributed
)
_loss_backward
(
output_single_node
,
output_distributed
)
# Compute delayed weight gradient
if
"delay_wgrad_compute"
in
kwargs
:
_loss_backward_dw
(
model_single_node
,
model_distributed
)
# Validate outputs and gradients
# Validate outputs and gradients
_check_outputs
(
output_single_node
,
output_distributed
)
_check_outputs
(
output_single_node
,
output_distributed
)
...
@@ -492,6 +504,7 @@ def test_linear():
...
@@ -492,6 +504,7 @@ def test_linear():
{
"fuse_wgrad_accumulation"
:
True
},
{
"fuse_wgrad_accumulation"
:
True
},
{
"return_bias"
:
True
},
{
"return_bias"
:
True
},
{
"params_dtype"
:
torch
.
float16
},
{
"params_dtype"
:
torch
.
float16
},
{
"delay_wgrad_compute"
:
True
},
]
]
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
for
parallel_mode
in
[
"column"
,
"row"
]:
for
parallel_mode
in
[
"column"
,
"row"
]:
...
@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
...
@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
# Compute loss and backpropagate
# Compute loss and backpropagate
_loss_backward
(
output_single_node
,
output_distributed
)
_loss_backward
(
output_single_node
,
output_distributed
)
# Compute delayed weight gradient
if
"delay_wgrad_compute"
in
kwargs
:
_loss_backward_dw
(
model_single_node
,
model_distributed
)
# Validate outputs and gradients
# Validate outputs and gradients
_check_outputs
(
output_single_node
,
output_distributed
)
_check_outputs
(
output_single_node
,
output_distributed
)
...
@@ -665,6 +682,7 @@ def test_layernorm_linear():
...
@@ -665,6 +682,7 @@ def test_layernorm_linear():
{
"params_dtype"
:
torch
.
float16
},
{
"params_dtype"
:
torch
.
float16
},
{
"zero_centered_gamma"
:
False
},
{
"zero_centered_gamma"
:
False
},
{
"return_layernorm_output"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"delay_wgrad_compute"
:
True
},
]
]
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
for
parallel_mode
in
[
"column"
]:
for
parallel_mode
in
[
"column"
]:
...
@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
...
@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
# Compute loss and backpropagate
# Compute loss and backpropagate
_loss_backward
(
output_single_node
,
output_distributed
)
_loss_backward
(
output_single_node
,
output_distributed
)
if
"delay_wgrad_compute"
in
kwargs
:
_loss_backward_dw
(
model_single_node
,
model_distributed
)
# Validate outputs and gradients
# Validate outputs and gradients
_check_outputs
(
output_single_node
,
output_distributed
)
_check_outputs
(
output_single_node
,
output_distributed
)
...
@@ -769,6 +790,7 @@ def test_layernorm_mlp():
...
@@ -769,6 +790,7 @@ def test_layernorm_mlp():
{
"fuse_wgrad_accumulation"
:
True
},
{
"fuse_wgrad_accumulation"
:
True
},
{
"return_bias"
:
True
},
{
"return_bias"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"return_layernorm_output"
:
True
},
{
"delay_wgrad_compute"
:
True
},
]
]
for
kwargs
in
kwargs_list
:
for
kwargs
in
kwargs_list
:
...
...
tests/pytorch/distributed/test_numerics.py
View file @
ab3e5a92
...
@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
...
@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
fp8_available
,
reason_for_no_fp8
=
FP8GlobalStateManager
.
is_fp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
mxfp8_available
,
reason_for_no_mxfp8
=
FP8GlobalStateManager
.
is_mxfp8_available
()
fp8_block_scaling_available
,
reason_for_no_fp8_block_scaling
=
(
FP8GlobalStateManager
.
is_fp8_block_scaling_available
()
)
TEST_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
TEST_ROOT
=
Path
(
__file__
).
parent
.
resolve
()
NUM_PROCS
:
int
=
min
(
4
,
torch
.
cuda
.
device_count
())
NUM_PROCS
:
int
=
min
(
4
,
torch
.
cuda
.
device_count
())
...
@@ -48,7 +51,7 @@ def _run_test(quantization):
...
@@ -48,7 +51,7 @@ def _run_test(quantization):
all_boolean
=
[
True
,
False
]
all_boolean
=
[
True
,
False
]
@
pytest
.
mark
.
parametrize
(
"quantization"
,
[
None
,
"fp8"
,
"mxfp8"
,
"fp8_cs"
])
@
pytest
.
mark
.
parametrize
(
"quantization"
,
[
None
,
"fp8"
,
"mxfp8"
,
"fp8_cs"
,
"fp8_block_scaling"
])
def
test_distributed
(
quantization
):
def
test_distributed
(
quantization
):
if
quantization
==
"fp8"
and
not
fp8_available
:
if
quantization
==
"fp8"
and
not
fp8_available
:
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
...
@@ -56,4 +59,6 @@ def test_distributed(quantization):
...
@@ -56,4 +59,6 @@ def test_distributed(quantization):
pytest
.
skip
(
reason_for_no_fp8
)
pytest
.
skip
(
reason_for_no_fp8
)
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
if
quantization
==
"mxfp8"
and
not
mxfp8_available
:
pytest
.
skip
(
reason_for_no_mxfp8
)
pytest
.
skip
(
reason_for_no_mxfp8
)
if
quantization
==
"fp8_block_scaling"
and
not
fp8_block_scaling_available
:
pytest
.
skip
(
reason_for_no_fp8_block_scaling
)
_run_test
(
quantization
)
_run_test
(
quantization
)
tests/pytorch/references/blockwise_fp8_gemm_reference.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from
typing
import
Tuple
import
torch
import
triton
import
triton.language
as
tl
@triton.jit
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
    # Elementwise in-place FMA over a logical (M, N) matrix: y = x * s + y.
    # Each program handles BLOCK consecutive elements of the flattened matrix,
    # so the launch grid is 1-D.
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    # Guard the tail program when M*N is not a multiple of BLOCK.
    mask = idx < M * N
    row = idx // N
    col = idx % N
    # `y` may be non-contiguous, so address it via its explicit strides;
    # `x` and `s` are required to be contiguous by the host wrapper,
    # hence the plain row-major offset.
    y_offset = row * y_str0 + col * y_str1
    x_offset = row * N + col
    s_offset = row * N + col
    y = tl.load(y_ptr + y_offset, mask=mask)
    x = tl.load(x_ptr + x_offset, mask=mask)
    s = tl.load(s_ptr + s_offset, mask=mask)
    # Single fused multiply-add (one rounding step), stored back into y.
    tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)
def fused_fma(y, x, s, BLOCK=128):
    """
    Fused multiply-add operation (y = y + x * s).
    PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise equivalent to this operation).
    This function also supports cases where 'y' is non-contiguous in memory.
    """
    assert (
        y.shape == x.shape == s.shape and y.dim() == 2
    ), "All tensors must be 2D with the same shape"
    # The kernel uses row-major offsets for x and s, so they must be contiguous.
    assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
    M, N = y.shape
    # 1-D launch: one program per BLOCK elements, rounded up.
    grid = ((M * N + BLOCK - 1) // BLOCK,)
    # y's strides are forwarded so the kernel can handle non-contiguous y.
    fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
    # y is updated in place; returned for call-chaining convenience.
    return y
class CuBLASRefBlockwiseGemm:
    """
    A cuBLAS compatible reference implementation of subchannel GEMM.
    """

    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        out_dtype: torch.dtype,
        demunged_sx: torch.Tensor,
        demunged_sw: torch.Tensor,
        quant_tile_shape_x: Tuple[int, int],
        quant_tile_shape_w: Tuple[int, int],
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        use_split_accumulator: bool = False,
    ) -> torch.Tensor:
        """Quantized GEMM entry point: dispatches to the blockwise reference
        and applies bias / output accumulation in cuBLAS-matching order.

        Args (as used below):
            qx, qw: quantized operands, shapes (M, K) and (N, K).
            quant_tile_shape_x/_w: quantization tile shapes; a leading 1
                marks 1-D (per-row-of-tiles) scaling.
            accumulate: if True, `out` is added after the cast to out_dtype.
        """
        # demunge scale shapes for cuBLAS
        is_a_1d_scaled = quant_tile_shape_x[0] == 1
        is_b_1d_scaled = quant_tile_shape_w[0] == 1
        M, K = qx.shape
        N, K = qw.shape
        # mm_tile_shape = (tile_m, tile_n, tile_k)
        mm_tile_shape = (
            quant_tile_shape_x[0],
            quant_tile_shape_w[0],
            quant_tile_shape_w[1],
        )
        if bias is not None and bias.numel():
            # To match cuBLAS more closely when bias is applied,
            # the reference accumulates into float32, and cast to
            # bfloat16 is deferred until after the GEMM.
            out_dtype_for_ref = torch.float32
        else:
            out_dtype_for_ref = out_dtype
        y = self.qgemm_blockwise_2d(
            qx,
            qw,
            out_dtype_for_ref,
            demunged_sx,
            demunged_sw,
            mm_tile_shape,
            use_split_accumulator,
            is_a_1d_scaled,
            is_b_1d_scaled,
        )
        if bias is not None and bias.numel():
            y += bias
        y = y.to(dtype=out_dtype)
        # cublas accumulation first convert to output dtype, then accumulate.
        if accumulate:
            assert out is not None
            y = y + out
        else:
            assert out is None, "Output tensor should be None when accumulate is False."
        return y

    @classmethod
    def qgemm_blockwise_2d(
        cls,
        qx: torch.Tensor,
        qw: torch.Tensor,
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        mm_tile_shape: Tuple[int, int, int],
        use_split_accumulator: bool,
        is_a_1d_scaled: bool,
        is_b_1d_scaled: bool,
    ) -> torch.Tensor:
        """
        Difference between cuBLAS and CUTLASS GEMM implementations:
        - cuBLAS accumulation equation: use different equation for each scaling mode.
        - For accumulation C in epilogue, it first converts C to output dtype, then accumulates.
        """
        M, K = qx.shape
        N, K_w = qw.shape
        assert K == K_w, "K dimension mismatch between qx and qw"
        # 128x128 outer tiling, independent of the quantization tile shape.
        tile_len = 128
        # Calculate grid sizes without padding
        grid_m = (M + tile_len - 1) // tile_len
        grid_n = (N + tile_len - 1) // tile_len
        grid_k = (K + tile_len - 1) // tile_len
        block_m, block_n, block_k = mm_tile_shape
        # Number of scale rows/cols per 128-wide outer tile.
        scale_m_per_tile = tile_len // block_m
        scale_n_per_tile = tile_len // block_n
        assert block_k == tile_len, "block_k must be equal to tile_len"
        # Notes on making the reference implementation numerically equivalent to Cast Blockwise FP8 GEMM:
        # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation results are accumulated into float32 registers.
        # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add) instructions to apply scaling factors, as in: y += partial_y * scale
        y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
        # Validate shapes of sx and sw
        scale_m_per_tensor = (M + block_m - 1) // block_m
        scale_n_per_tensor = (N + block_n - 1) // block_n
        assert sx.shape == (
            scale_m_per_tensor,
            grid_k,
        ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
        assert sw.shape == (
            scale_n_per_tensor,
            grid_k,
        ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
        for i in range(grid_m):
            m_start = i * tile_len
            m_end = min(m_start + tile_len, M)
            m_size = m_end - m_start  # not used below; kept for readability
            for j in range(grid_n):
                n_start = j * tile_len
                n_end = min(n_start + tile_len, N)
                n_size = n_end - n_start  # not used below; kept for readability
                # View into the output tile; fused_fma updates it in place.
                y_block = y[m_start:m_end, n_start:n_end]
                for k in range(grid_k):
                    k_start = k * tile_len
                    k_end = min(k_start + tile_len, K)
                    k_size = k_end - k_start  # not used below; kept for readability
                    qx_block = (
                        qx[m_start:m_end, k_start:k_end].clone().contiguous()
                    )  # Shape: [m_size, k_size]
                    qw_block = (
                        qw[n_start:n_end, k_start:k_end].clone().contiguous()
                    )  # Shape: [n_size, k_size]
                    # Extract scaling factors for the current blocks
                    sx_block = sx[
                        i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k
                    ].unsqueeze(-1)
                    sw_block = sw[
                        j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k
                    ].unsqueeze(0)
                    # Perform qgemm with scaling factors fused in the GEMM
                    # Accumulate should be in float32 format, which aligns with the split_accumulate in FP8 GEMM
                    one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
                    y_partial = torch._scaled_mm(
                        qx_block,
                        qw_block.t(),
                        scale_a=one,
                        scale_b=one,
                        out_dtype=torch.float32,
                        use_fast_accum=not use_split_accumulator,
                    )
                    # Accumulate the partial result
                    if is_a_1d_scaled and is_b_1d_scaled:
                        # 1Dx1D
                        # CuBLAS accumulation equation: y += (y * scale_a) * scale_b
                        y_partial = y_partial * sx_block
                        # Fuse multiplication and addition to align with the split_accumulate in FP8 GEMM
                        # y_block.add_(y_partial, alpha=scale.item())
                        fused_fma(
                            y_block,
                            y_partial,
                            sw_block.expand_as(y_partial).contiguous(),
                        )
                    elif not is_a_1d_scaled and is_b_1d_scaled:
                        # 2Dx1D
                        # CuBLAS accumulation equation: y += (y * scale_b) * scale_a
                        y_partial = y_partial * sw_block
                        fused_fma(
                            y_block,
                            y_partial,
                            sx_block.expand_as(y_partial).contiguous(),
                        )
                    elif is_a_1d_scaled and not is_b_1d_scaled:
                        # 1Dx2D
                        # CuBLAS accumulation equation: y += (y * scale_a) * scale_b
                        y_partial = y_partial * sx_block
                        fused_fma(
                            y_block,
                            y_partial,
                            sw_block.expand_as(y_partial).contiguous(),
                        )
                    else:
                        # 2Dx2D: both scales can be combined into one factor.
                        scale = sx_block * sw_block
                        fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
        y = y.to(out_dtype)
        return y
tests/pytorch/references/blockwise_quantizer_reference.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import
dataclasses
import
math
import
torch
from
typing
import
Optional
,
Protocol
,
Tuple
from
references.quantize_scale_calc
import
scale_from_amax_tensor
@dataclasses.dataclass()
class QuantizeResult:
    # Quantized tensor data.
    data: torch.Tensor
    # Scale factors associated with `data`.
    scale: torch.Tensor
    # Optional transposed quantized data; None when no transpose was produced.
    data_t: Optional[torch.Tensor]
    # Optional scale factors for `data_t`; None when no transpose was produced.
    scale_t: Optional[torch.Tensor]
@dataclasses.dataclass()
class CuBLASScaleMunger:
    # Reshapes/pads blockwise scale tensors between the layout produced by the
    # quantizer and the layout cuBLAS GEMMs expect.

    def munge_scale_shapes_for_backend(
        self,
        unmunged: QuantizeResult,
        tile_shape: Tuple[int, int],
    ) -> QuantizeResult:
        """
        cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed
        so that for an (M, N) tensor, the scales are (RoundUpDiv(N, 128), RoundUp(M, 4))

        For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4))
        format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required
        """

        def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
            # Optionally transpose (1-D scaling case), then zero-pad the inner
            # dimension up to the next multiple of 4.
            if transpose:
                s = s.transpose(-1, -2).contiguous()
            M, K = s.shape
            if K % 4 == 0:
                return s
            k_pad = 4 - (K % 4)
            return torch.nn.functional.pad(
                s, (0, k_pad), mode="constant", value=0
            ).contiguous()

        # tile_shape[0] == 1 means 1x128 (1-D) scaling, which needs the transpose.
        s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1)
        if unmunged.scale_t is None:
            s_t = None
        else:
            s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
        return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)

    @classmethod
    def demunge_scale_shape_from_backend(
        cls,
        qtensor_shape: Tuple[int, int],
        scales: torch.Tensor,
        tile_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Inverse operation of munge_scale_shapes_for_backend
        """
        if tile_shape[0] != 1:
            # 2D block quantized tensor may need padding stripped off
            derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1])
        else:
            derived_scale_k_shape = qtensor_shape[0]
        M, K = scales.shape
        if derived_scale_k_shape != K:
            # Strip the alignment padding added by the munge step.
            scales = scales[:, :derived_scale_k_shape].contiguous()
        if tile_shape[0] == 1:
            # Undo the transpose applied for 1-D (1x128) scaling.
            return scales.transpose(-1, -2).contiguous()
        else:
            return scales
@dataclasses.dataclass()
class BlockwiseQuantizerReference:
    """
    A reference QuantizeOp for subchannel/block hybrid quantization.
    Defers to ref GEMMs and quantization formatting based on the backend.
    """

    def __init__(self) -> None:
        # Handles the scale layout transformations cuBLAS expects.
        self.scale_munger = CuBLASScaleMunger()

    @classmethod
    def _quantize_square_block_tiling(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        return_transpose: bool,
        pow_2_scales: bool,
        eps: float,
    ) -> QuantizeResult:
        """Quantize a 2D tensor with one scale per (tile_len x tile_len) block.

        x is zero-padded so both dims are multiples of tile_len, quantized,
        then sliced back to the original shape. Scales are per-block
        (M // tile_len, K // tile_len).
        """
        M, K = x.shape
        # pad_m_k[0] pads rows, pad_m_k[1] pads columns, each up to a tile multiple.
        pad_m_k = [0, 0]
        if K % tile_len != 0:
            pad_m_k[1] = tile_len - (K % tile_len)
        if M % tile_len != 0:
            pad_m_k[0] = tile_len - (M % tile_len)
        unpadded_m, unpadded_k = M, K
        if pad_m_k[0] != 0 or pad_m_k[1] != 0:
            # Zero padding does not affect per-block amax (abs of 0 never wins).
            x = torch.nn.functional.pad(
                x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0
            ).contiguous()
            M, K = x.shape
        # (M//t, t, K//t, t): dims 1 and 3 index within a block.
        x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len)
        # Bring both intra-block dims together so one amax covers the whole block.
        amax_grid = (
            torch.abs(x_tiled.transpose(-3, -2))
            .reshape(M // tile_len, K // tile_len, tile_len**2)
            .amax(dim=-1)
        ).float()
        dtype_max = torch.finfo(quant_dtype).max
        # scale multiplies x into the quantized range; scale_inv is kept for dequant.
        scale, scale_inv, _ = scale_from_amax_tensor(
            x_dtype=x.dtype,
            amax=amax_grid,
            quant_dtype=quant_dtype,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        # Broadcast each block's scale over its tile_len x tile_len elements.
        qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
        # Clamp before the cast so out-of-range values saturate instead of
        # relying on dtype conversion behavior.
        qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
        qx = qx.to(dtype=quant_dtype)
        qx = qx.reshape(M, K)
        if unpadded_k != K or unpadded_m != M:
            # Strip the padding back off the quantized output.
            qx = qx[:unpadded_m, :unpadded_k].contiguous()
        if return_transpose:
            # Valid because of square block sizes
            qx_t = qx.transpose(-1, -2).contiguous()
            scale_inv_t = scale_inv.transpose(-1, -2).contiguous()
        else:
            qx_t = None
            scale_inv_t = None
        return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t)

    @classmethod
    def _quantize_vectorwise_reference(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        pow_2_scales: bool,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize with one scale per 1 x tile_len row segment.

        Requires K to already be a multiple of tile_len (callers pad first).
        Returns (quantized tensor, inverse scales of shape (M, K // tile_len)).
        """
        M, K = x.shape
        dtype_max = torch.finfo(quant_dtype).max
        x_tiled = x.reshape(M, K // tile_len, tile_len)
        amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
        scale, scale_inv, _ = scale_from_amax_tensor(
            x_dtype=x.dtype,
            amax=amax_grid,
            quant_dtype=quant_dtype,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        # Broadcast each segment's scale across its tile_len elements.
        qx = x_tiled * scale.reshape(M, K // tile_len, 1)
        # Saturate before casting to the quantized dtype.
        qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
        qx = qx.to(dtype=quant_dtype)
        qx = qx.reshape(M, K)
        return qx, scale_inv

    @classmethod
    def _quantize_vector_tiling(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        return_transpose: bool,
        pow_2_scales: bool,
        eps: float,
    ) -> QuantizeResult:
        """Row-wise (1 x tile_len) quantization with optional transpose output.

        Pads K (and M for the transpose path) to a tile multiple, quantizes via
        _quantize_vectorwise_reference, then slices the padding back off the data.
        Note the scales are computed on the padded tensor and keep that shape.
        """
        M, K = x.shape
        if K % tile_len == 0:
            qref_input = x
        else:
            pad_amount = tile_len - (K % tile_len)
            pad = (0, pad_amount)
            qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        qout_padded, scale_inv = cls._quantize_vectorwise_reference(
            qref_input,
            quant_dtype,
            tile_len=tile_len,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        if K % tile_len == 0:
            qout = qout_padded
        else:
            qout = qout_padded[:, :K].contiguous()
        if return_transpose:
            # The transpose path quantizes x^T row-wise, so M becomes the inner
            # dim and may itself need padding to a tile multiple.
            if M % tile_len == 0:
                qref_input = x.transpose(-1, -2).contiguous()
            else:
                amount_to_pad = tile_len - (M % tile_len)
                pad = (0, amount_to_pad)
                qref_input = torch.nn.functional.pad(
                    x.transpose(-1, -2), pad, mode="constant", value=0
                ).contiguous()
            qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference(
                qref_input,
                quant_dtype,
                tile_len=tile_len,
                pow_2_scales=pow_2_scales,
                eps=eps,
            )
            if M % tile_len == 0:
                qout_t = qout_t_padded
            else:
                qout_t = qout_t_padded[:, :M].contiguous()
        else:
            qout_t, scale_inv_t = None, None
        return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t)

    def ref_dequantize_rowwise(
        self,
        q: torch.Tensor,
        quant_tile_shape: Tuple[int, int],
        s: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Dequantize q back to `dtype` using backend-formatted scales s.

        s is first demunged from the cuBLAS layout; q is zero-padded to tile
        multiples so it can be reshaped into (m_tiles, tile_m, k_tiles, tile_k)
        and multiplied by per-tile scales, then sliced back to q's shape.
        """
        assert q.dim() == 2
        q_M, q_K = q.shape
        # Undo the backend scale layout (transpose/pad) before applying scales.
        s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape)
        assert len(s.shape) == 2
        m_tiles, k_tiles = s.shape
        M, K = q.shape
        unpadded_m, unpadded_k = M, K
        if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0:
            # The outer % makes the pad 0 when the dim is already a multiple.
            m_pad_amount = (
                quant_tile_shape[0] - (M % quant_tile_shape[0])
            ) % quant_tile_shape[0]
            k_pad_amount = (
                quant_tile_shape[1] - (K % quant_tile_shape[1])
            ) % quant_tile_shape[1]
            q = torch.nn.functional.pad(
                q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
            ).contiguous()
            M, K = q.shape
        q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1])
        # Cast up before multiplying so the product is computed in `dtype`.
        result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1)
        result = result.view(M, K).to(dtype)
        if M != unpadded_m or K != unpadded_k:
            result = result[:unpadded_m, :unpadded_k].contiguous()
        return result

    def quantize(
        self,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        return_transpose: bool = False,
        eps: float = 0.0,
        pow_2_scales: bool = False,
        quant_tile_shape: Tuple[int, int] = (128, 128),
    ) -> QuantizeResult:
        """Quantize a 2D tensor to fp8 with 1x128 (row-wise) or 128x128 (block) tiles.

        The result's scales are munged into the layout the cuBLAS backend expects.
        """
        # sanity checks
        assert x.dim() == 2
        assert x.dtype in (
            torch.float,
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ), "Unsupported input dtype."
        assert quant_dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ), "Unsupported quant dtype."
        assert quant_tile_shape in ((1, 128), (128, 128))
        if quant_tile_shape[0] == 1:
            # Quantize row-wise
            return self.scale_munger.munge_scale_shapes_for_backend(
                self._quantize_vector_tiling(
                    x,
                    quant_dtype,
                    tile_len=quant_tile_shape[1],
                    return_transpose=return_transpose,
                    pow_2_scales=pow_2_scales,
                    eps=eps,
                ),
                quant_tile_shape,
            )
        else:
            # Quantize block-wise
            return self.scale_munger.munge_scale_shapes_for_backend(
                self._quantize_square_block_tiling(
                    x,
                    quant_dtype,
                    tile_len=quant_tile_shape[0],
                    return_transpose=return_transpose,
                    pow_2_scales=pow_2_scales,
                    eps=eps,
                ),
                quant_tile_shape,
            )
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment