OpenDAS / TransformerEngine / Commits

Commit ab3e5a92, authored May 09, 2025 by yuguo

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine

Parents: a8d19fd9, 04c730c0
Changes: 174 files

Showing 20 changed files with 2561 additions and 566 deletions (+2561 −566)
tests/cpp/operator/CMakeLists.txt                             +1    −0
tests/cpp/operator/test_cast_float8blockwise.cu               +655  −0
tests/cpp/operator/test_normalization.cu                      +30   −159
tests/cpp/operator/test_normalization.h                       +188  −0
tests/cpp/operator/test_normalization_mxfp8.cu                +24   −77
tests/cpp/test_common.cu                                      +97   −53
tests/cpp/test_common.h                                       +11   −12
tests/jax/pytest.ini                                          +2    −0
tests/jax/test_custom_call_compute.py                         +152  −122
tests/jax/test_distributed_fused_attn.py                      +216  −69
tests/jax/test_distributed_layernorm.py                       +7    −1
tests/jax/test_distributed_layernorm_mlp.py                   +112  −46
tests/jax/test_distributed_softmax.py                         +82   −14
tests/jax/test_layer.py                                       +46   −5
tests/jax/test_softmax.py                                     +91   −5
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py   +273  −1
tests/pytorch/distributed/run_numerics.py                     +23   −1
tests/pytorch/distributed/test_numerics.py                    +6    −1
tests/pytorch/references/blockwise_fp8_gemm_reference.py      +242  −0
tests/pytorch/references/blockwise_quantizer_reference.py     +303  −0
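The centerpiece of this merge is the new blockwise FP8 cast test, test_cast_float8blockwise.cu (shown below): every 128x128 (2D) or 1x128 (1D) block of the input is scaled by qscale = fp8_max / amax before the cast, optionally with qscale rounded down to a power of two. A minimal standalone sketch of that rounding step, assuming the E4M3 maximum of 448.0f (the names here are illustrative, not code from the diff):

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Illustrative constant: maximum representable magnitude in FP8 E4M3.
constexpr float kE4M3Max = 448.0f;

// Derive a quantization scale from a block's amax; optionally round it
// down to a power of two by keeping only the IEEE-754 exponent.
float scale_from_amax(float amax, bool force_pow_2) {
  float qscale = kE4M3Max / amax;
  if (force_pow_2 && qscale != 0.0f) {
    uint32_t bits;
    std::memcpy(&bits, &qscale, sizeof(bits));       // positive, so no sign bit
    int32_t exp = static_cast<int32_t>(bits >> 23);  // biased exponent
    qscale = std::ldexp(1.0f, exp - 127);            // 2^(unbiased exponent)
  }
  return qscale;
}

int main() {
  // amax = 3.0 gives 448/3 ~= 149.33; forcing pow-2 rounds it down to 128.
  std::printf("free: %g, pow2: %g\n",
              scale_from_amax(3.0f, false), scale_from_amax(3.0f, true));
}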
tests/cpp/operator/CMakeLists.txt

@@ -11,6 +11,7 @@ list(APPEND test_cuda_sources
     test_cast_mxfp8_gated_swiglu.cu
     test_qdq.cu
     test_cast_mxfp8.cu
+    # test_cast_float8blockwise.cu
     test_dequantize_mxfp8.cu
     test_transpose.cu
     test_cast_transpose.cu
tests/cpp/operator/test_cast_float8blockwise.cu  (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"
#include "transformer_engine/transformer_engine.h"

using namespace transformer_engine;
using namespace test;

namespace {

struct QuantizationOptions {
  bool force_pow_2_scales = false;
  float amax_epsilon = 0.0;
  size_t block_scaling_dim = 2u;
};

constexpr size_t kBlockLen = 128;

enum ProcessingMethod {
  CAST_ONLY,
  // CAST_DBIAS,
  // CAST_DBIAS_DACT,
  // CAST_DACT,
  // CAST_ACT
};

enum ActivationType {
  Identity,
  // GeLU,
  // SiLU,
  // ReLU,
  // QGeLU,
  // SReLU
};

template <typename InputType, typename OutputType>
void scales_from_amax(float amax, const QuantizationOptions& opts,
                      float* qscale_out, float* qscale_inv_out) {
  float input_type_max_val = Quantized_Limits<InputType>::max();
  float quant_type_max_val = Quantized_Limits<OutputType>::max();
  float eps = opts.amax_epsilon;
  amax = std::max(amax, eps);

  float qscale = quant_type_max_val / amax;
  if (std::isinf(qscale)) {
    qscale = input_type_max_val;
  }
  if (std::isnan(qscale) || amax == 0) {
    qscale = 1.0;
  }
  if (opts.force_pow_2_scales && qscale != 0.0) {
    uint32_t scale_bits = *reinterpret_cast<uint32_t*>(&qscale);
    // Scale must be positive, shift it
    uint8_t exp = scale_bits >> 23;
    ASSERT_FALSE(exp == 0) << "Subnormals in this path is a logic error.";
    qscale = ldexpf(1.0f, static_cast<int32_t>(exp) - 127);
  }

  float qscale_inv = 1.0 / qscale;
  *qscale_out = qscale;
  *qscale_inv_out = qscale_inv;
}

template <typename InputType, typename OutputType>
void ref_quantize(const ProcessingMethod processing_method,
                  const InputType* input,
                  const std::pair<size_t, size_t>& input_hw,
                  OutputType* output, float* scale_inv,
                  OutputType* output_t, float* scale_inv_t,
                  const QuantizationOptions& opts) {
  constexpr size_t kBlockLenX = kBlockLen;
  constexpr size_t kBlockLenY = kBlockLen;

  auto quantize_element = [](InputType element, float qscale) -> OutputType {
    // Scale in FP32 and cast result to nearest FP8.
    return static_cast<OutputType>(float(element) * qscale);
  };

  size_t height = input_hw.first;
  size_t width = input_hw.second;
  size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX;
  size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY;

  // Find the absolute maximum value in the block
  for (size_t block_x = 0; block_x < blocks_x; ++block_x) {
    for (size_t block_y = 0; block_y < blocks_y; ++block_y) {
      float amax = 0.0f;
      // Calculate amax for a tile.
      for (size_t i = 0; i < kBlockLenX; ++i) {
        for (size_t j = 0; j < kBlockLenY; ++j) {
          size_t x_pos = i + block_x * kBlockLenX;
          size_t y_pos = j + block_y * kBlockLenY;
          if (y_pos >= height || x_pos >= width) {
            continue;
          }
          float val = static_cast<float>(input[y_pos * width + x_pos]);
          amax = std::max(amax, std::abs(val));
        }
      }
      // We've calculated amax for a tile. Calculate scale and
      // scale_inv and populate outputs.
      float qscale, qscale_inv;
      scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);

      // NOTE: This reference function outputs contiguous scale tensors.
      // It calculates a naive scale data format. Strides are handled
      // in comparison.
      if (scale_inv != nullptr) {
        scale_inv[block_y * blocks_x + block_x] = qscale_inv;
      }
      if (scale_inv_t != nullptr) {
        scale_inv_t[block_x * blocks_y + block_y] = qscale_inv;
      }
      for (size_t i = 0; i < kBlockLenX; ++i) {
        for (size_t j = 0; j < kBlockLenY; ++j) {
          size_t x_pos = i + block_x * kBlockLenX;
          size_t y_pos = j + block_y * kBlockLenY;
          if (y_pos >= height || x_pos >= width) {
            continue;
          }
          if (output != nullptr) {
            output[y_pos * width + x_pos] =
                quantize_element(input[y_pos * width + x_pos], qscale);
          }
          if (output_t != nullptr) {
            output_t[x_pos * height + y_pos] =
                quantize_element(input[y_pos * width + x_pos], qscale);
          }
        }
      }
    }
  }
}

template <typename InputType, typename OutputType>
void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method,
                                        const InputType* input,
                                        const std::pair<size_t, size_t>& input_hw,
                                        OutputType* output, float* scale_inv,
                                        OutputType* output_t, float* scale_inv_t,
                                        const QuantizationOptions& opts) {
  float input_type_max_val = Quantized_Limits<InputType>::max();
  float quant_type_max_val = Quantized_Limits<OutputType>::max();
  constexpr size_t kBlockLenX = kBlockLen;

  auto quantize_element = [](InputType element, float qscale) -> OutputType {
    // Scale in FP32 and cast result to nearest FP8.
    return static_cast<OutputType>(float(element) * qscale);
  };

  size_t height = input_hw.first;
  size_t width = input_hw.second;
  size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX;
  size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX;

  if (output != nullptr && scale_inv != nullptr) {
    // Find the absolute maximum value in the block
    for (size_t block_x = 0; block_x < blocks_x; ++block_x) {
      for (size_t y = 0; y < height; ++y) {
        float amax = 0.0f;
        // Calculate amax for a tile.
        for (size_t i = 0; i < kBlockLenX; ++i) {
          size_t x_pos = i + block_x * kBlockLenX;
          if (x_pos >= width) {
            continue;
          }
          float val = static_cast<float>(input[y * width + x_pos]);
          amax = std::max(amax, std::abs(val));
        }
        // We've calculated amax for a tile. Calculate scale and
        // scale_inv and populate outputs.
        float qscale, qscale_inv;
        scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);
        scale_inv[y + height * block_x] = qscale_inv;
        for (size_t i = 0; i < kBlockLenX; ++i) {
          size_t x_pos = i + block_x * kBlockLenX;
          if (x_pos >= width) {
            continue;
          }
          output[y * width + x_pos] =
              quantize_element(input[y * width + x_pos], qscale);
        }
      }
    }
  }
  if (output_t != nullptr && scale_inv_t != nullptr) {
    // Find the absolute maximum value in the block
    for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) {
      for (size_t x = 0; x < width; ++x) {
        float amax = 0.0f;
        // Calculate amax for a tile.
        for (size_t i = 0; i < kBlockLenX; ++i) {
          size_t y_pos = i + block_x_t * kBlockLenX;
          if (y_pos >= height) {
            continue;
          }
          float val = static_cast<float>(input[x + y_pos * width]);
          amax = std::max(amax, std::abs(val));
        }
        // We've calculated amax for a tile. Calculate scale and
        // scale_inv and populate outputs.
        float qscale, qscale_inv;
        scales_from_amax<InputType, OutputType>(amax, opts, &qscale, &qscale_inv);
        scale_inv_t[x + width * block_x_t] = qscale_inv;
        for (size_t i = 0; i < kBlockLenX; ++i) {
          size_t y_pos = i + block_x_t * kBlockLenX;
          if (y_pos >= height) {
            continue;
          }
          output_t[x * height + y_pos] =
              quantize_element(input[y_pos * width + x], qscale);
        }
      }
    }
  }
}

inline size_t scale_align_stride(size_t inner_elements) {
  return ((inner_elements + 4u - 1u) / 4u) * 4u;
};

void compare_scaling_factors(const std::string& name, const float* test,
                             const float* ref, const size_t row_blocks,
                             const size_t col_blocks, const size_t test_stride,
                             const size_t ref_stride) {
  for (int i = 0; i < row_blocks; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int test_idx = i * test_stride + j;
      const int ref_idx = i * ref_stride + j;
      ASSERT_FALSE(test[test_idx] != ref[ref_idx])
          << "Error in " << name << std::endl
          << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx]
          << " at index " << test_idx << "," << ref_idx;
    }
  }
}

void compare_scaling_factors_one_dimensional_blocks(
    const std::string& name, const float* test, const float* ref,
    const size_t rows, const size_t col_blocks) {
  const size_t test_stride = scale_align_stride(rows);
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < col_blocks; ++j) {
      const int test_idx = i + test_stride * j;
      const int ref_idx = i + rows * j;
      ASSERT_FALSE(test[test_idx] != ref[ref_idx])
          << "Error in " << name << std::endl
          << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx]
          << " at index " << test_idx << "," << ref_idx;
    }
  }
}

template <typename InputType, typename OutputType>
void runTestCase(const ProcessingMethod processing_method,
                 const std::vector<size_t>& shape, const bool rowwise,
                 const bool colwise, InputsFillCase fill_case,
                 const QuantizationOptions& opts) {
  using namespace test;
  using EncodingType = fp32;
  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  const size_t rows = first_dimension(shape);
  const size_t cols = last_dimension(shape);
  size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen;
  size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen;

  Tensor input("input", shape, itype);
  Tensor grad("grad", shape, itype);
  Tensor output_c("output_c", shape, otype, rowwise, colwise,
                  opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D
                                              : NVTE_BLOCK_SCALING_1D);
  Tensor output_dbias("output_dbias", {cols}, itype);

  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<float[]> ref_scale_inv = std::make_unique<float[]>(blocks_y * blocks_x);
  std::unique_ptr<float[]> ref_scale_inv_t = std::make_unique<float[]>(blocks_y * blocks_x);
  if (!rowwise) {
    ref_output = nullptr;
    ref_scale_inv = nullptr;
  }
  if (!colwise) {
    ref_output_t = nullptr;
    ref_scale_inv_t = nullptr;
  }

  fillCase<EncodingType>(&input, fill_case);
  fillUniform(&grad);

  QuantizationConfigWrapper quant_config;
  quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
  quant_config.set_amax_epsilon(opts.amax_epsilon);

  Tensor workspace;
  switch (processing_method) {
    case ProcessingMethod::CAST_ONLY: {
      nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
      break;
    }
  }

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  ref_quantize<InputType, OutputType>(processing_method,
                                      input.rowwise_cpu_dptr<InputType>(),
                                      {rows, cols}, ref_output.get(),
                                      ref_scale_inv.get(), ref_output_t.get(),
                                      ref_scale_inv_t.get(), opts);

  float atol = 0.0;
  float rtol = 0.0;
  if (rowwise) {
    compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
    compare_scaling_factors("scale_inv",
                            output_c.rowwise_cpu_scale_inv_ptr<float>(),
                            ref_scale_inv.get(), blocks_y, blocks_x,
                            scale_align_stride(blocks_x), blocks_x);
  }
  if (colwise) {
    compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
    compare_scaling_factors("scale_inv_t",
                            output_c.columnwise_cpu_scale_inv_ptr<float>(),
                            ref_scale_inv_t.get(), blocks_x, blocks_y,
                            scale_align_stride(blocks_y), blocks_y);
  }
}

template <typename InputType, typename OutputType>
void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
                                     const std::vector<size_t>& shape,
                                     const bool rowwise, const bool colwise,
                                     InputsFillCase fill_case,
                                     const QuantizationOptions& opts) {
  using namespace test;
  using EncodingType = fp32;
  DType itype = TypeInfo<InputType>::dtype;
  DType otype = TypeInfo<OutputType>::dtype;

  const size_t rows = first_dimension(shape);
  const size_t cols = last_dimension(shape);
  size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen;
  size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen;

  Tensor input("input", shape, itype);
  Tensor grad("grad", shape, itype);
  Tensor output_c("output_c", shape, otype, rowwise, colwise,
                  opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D
                                              : NVTE_BLOCK_SCALING_1D);
  Tensor output_dbias("output_dbias", {cols}, itype);

  std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<OutputType[]> ref_output_t = std::make_unique<OutputType[]>(rows * cols);
  std::unique_ptr<float[]> ref_scale_inv = std::make_unique<float[]>(rows * blocks_x);
  std::unique_ptr<float[]> ref_scale_inv_t = std::make_unique<float[]>(cols * blocks_x_t);
  if (!rowwise) {
    ref_output = nullptr;
    ref_scale_inv = nullptr;
  }
  if (!colwise) {
    ref_output_t = nullptr;
    ref_scale_inv_t = nullptr;
  }

  fillCase<EncodingType>(&input, fill_case);
  fillUniform(&grad);

  Tensor workspace;
  QuantizationConfigWrapper quant_config;
  quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
  quant_config.set_amax_epsilon(opts.amax_epsilon);

  switch (processing_method) {
    case ProcessingMethod::CAST_ONLY: {
      nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
      break;
    }
  }

  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  ref_quantize_onedimensional_blocks<InputType, OutputType>(
      processing_method, input.rowwise_cpu_dptr<InputType>(), {rows, cols},
      ref_output.get(), ref_scale_inv.get(), ref_output_t.get(),
      ref_scale_inv_t.get(), opts);

  float atol = 0.0;
  float rtol = 0.0;
  if (rowwise) {
    compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
    compare_scaling_factors_one_dimensional_blocks(
        "scale_inv", output_c.rowwise_cpu_scale_inv_ptr<float>(),
        ref_scale_inv.get(), rows, blocks_x);
  }
  if (colwise) {
    compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
    compare_scaling_factors_one_dimensional_blocks(
        "scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr<float>(),
        ref_scale_inv_t.get(), cols, blocks_x_t);
  }
}

std::vector<std::vector<size_t>> matrix_sizes = {
    {1, 16},
    {65, 96},
    {256, 256},
    {993, 512},
    {256, 65536},
    {4096, 1632},
    {1024, 1},
    {16, 512},
    {1024},
    {8, 32, 1024},
    {16, 8, 4, 512},
};

std::vector<InputsFillCase> input_scenarios = {
    InputsFillCase::uniform,
};

std::vector<ProcessingMethod> processing_methods = {
    ProcessingMethod::CAST_ONLY,
    // ProcessingMethod::CAST_DBIAS,
    // ProcessingMethod::CAST_DBIAS_DACT,
    // ProcessingMethod::CAST_DACT,
    // ProcessingMethod::CAST_ACT,
};

// Only GeLU activation tests are supported
std::vector<ActivationType> Activation_types = {
    ActivationType::Identity,
    // ActivationType::GeLU,
    // ActivationType::SiLU,
    // ActivationType::ReLU,
    // ActivationType::QGeLU,
    // ActivationType::SReLU,
};

std::vector<float> amax_epsilons = {
    0.0f,
    1.0f,  // Make large to be observable.
};

}  // namespace

class FusedCastFloat8BlockwiseTestSuite
    : public ::testing::TestWithParam<
          std::tuple<ProcessingMethod, ActivationType, std::vector<size_t>,
                     transformer_engine::DType, transformer_engine::DType,
                     InputsFillCase, bool, float, bool>> {};

class FusedCastFloat8VectorwiseTestSuite
    : public ::testing::TestWithParam<
          std::tuple<ProcessingMethod, ActivationType, std::vector<size_t>,
                     transformer_engine::DType, transformer_engine::DType,
                     InputsFillCase, bool, float, bool>> {};

#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
  switch (OP_FUNC_TYPE) { \
    case ActivationType::Identity: { \
      constexpr auto OP = &identity; \
      { \
        __VA_ARGS__ \
      } \
    } break; \
  }

#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \
  switch (OP_FUNC_TYPE) { \
    case ActivationType::Identity: { \
      constexpr auto OP = &identity; \
      { \
        __VA_ARGS__ \
      } \
    } break; \
  }

TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
  if (getDeviceComputeCapability() < hopperComputeCapability) {
    GTEST_SKIP();
  }
  using namespace transformer_engine;
  using namespace test;

  const ProcessingMethod processing_method = std::get<0>(GetParam());
  const ActivationType Act_type = std::get<1>(GetParam());
  const auto matrix_size = std::get<2>(GetParam());
  const DType input_type = std::get<3>(GetParam());
  const DType output_type = std::get<4>(GetParam());
  const InputsFillCase fill_case = std::get<5>(GetParam());
  const bool colwise = std::get<6>(GetParam());
  const bool rowwise = true;
  const float eps = std::get<7>(GetParam());
  const bool force_pow_2 = std::get<8>(GetParam());

  QuantizationOptions q_opts;
  q_opts.force_pow_2_scales = force_pow_2;
  q_opts.amax_epsilon = eps;
  q_opts.block_scaling_dim = 2u;

  if (colwise && matrix_size.size() < 2) {
    // test_common Tensor initialization code does not
    // handle this case.
    GTEST_SKIP();
  }

  // Skips non Act tests if the Activation type is not an identity
  if (
      // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
      (processing_method == ProcessingMethod::CAST_ONLY) &&
      Act_type != ActivationType::Identity) {
    GTEST_SKIP();
  }
  // Skips Act tests if the Activation is an identity
  // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
  //      || processing_method == ProcessingMethod::CAST_DACT
  //      || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
  //   GTEST_SKIP();
  // }

  DACT_FUNC_SWITCH(Act_type, OP,
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
        runTestCase<InputType, OutputType>(processing_method, matrix_size,
                                           rowwise, colwise, fill_case,
                                           q_opts););););
}

TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
  if (getDeviceComputeCapability() < hopperComputeCapability) {
    GTEST_SKIP();
  }
  using namespace transformer_engine;
  using namespace test;

  const ProcessingMethod processing_method = std::get<0>(GetParam());
  const ActivationType Act_type = std::get<1>(GetParam());
  const auto matrix_size = std::get<2>(GetParam());
  const DType input_type = std::get<3>(GetParam());
  const DType output_type = std::get<4>(GetParam());
  const InputsFillCase fill_case = std::get<5>(GetParam());
  const bool colwise = std::get<6>(GetParam());
  const bool rowwise = true;
  const float eps = std::get<7>(GetParam());
  const bool force_pow_2 = std::get<8>(GetParam());

  QuantizationOptions q_opts;
  q_opts.force_pow_2_scales = force_pow_2;
  q_opts.amax_epsilon = eps;
  q_opts.block_scaling_dim = 1u;

  if (colwise && matrix_size.size() < 2) {
    // test_common Tensor initialization code does not
    // handle this case.
    GTEST_SKIP();
  }

  // Skips non Act tests if the Activation type is not an identity
  if (
      // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS)
      (processing_method == ProcessingMethod::CAST_ONLY) &&
      Act_type != ActivationType::Identity) {
    GTEST_SKIP();
  }
  // Skips Act tests if the Activation is an identity
  // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT
  //      || processing_method == ProcessingMethod::CAST_DACT
  //      || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) {
  //   GTEST_SKIP();
  // }

  DACT_FUNC_SWITCH(Act_type, OP,
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
        runTestCaseOneDimensionalBlocks<InputType, OutputType>(
            processing_method, matrix_size, rowwise, colwise, fill_case,
            q_opts););););
}

std::string to_string(const ProcessingMethod method) {
  switch (method) {
    case ProcessingMethod::CAST_ONLY: return "CAST_ONLY";
    // case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS";
    // case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT";
    // case ProcessingMethod::CAST_DACT: return "CAST_DACT";
    // case ProcessingMethod::CAST_ACT: return "CAST_ACT";
    default: return "";
  }
}

std::string to_string(const ActivationType Act_type) {
  switch (Act_type) {
    case ActivationType::Identity: return "Identity";
    // case ActivationType::GeLU: return "GeLU";
    // case ActivationType::SiLU: return "SiLU";
    // case ActivationType::ReLU: return "ReLU";
    // case ActivationType::QGeLU: return "QGeLU";
    // case ActivationType::SReLU: return "SReLU";
    default: return "";
  }
}

INSTANTIATE_TEST_SUITE_P(
    OperatorTest, FusedCastFloat8BlockwiseTestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(processing_methods),
        ::testing::ValuesIn(Activation_types),
        ::testing::ValuesIn(matrix_sizes),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
        ::testing::ValuesIn(input_scenarios),
        ::testing::Values(true, false),
        ::testing::ValuesIn(amax_epsilons),
        ::testing::Values(true, false)),
    [](const testing::TestParamInfo<FusedCastFloat8BlockwiseTestSuite::ParamType>& info) {
      std::string name = to_string(std::get<0>(info.param)) + "X" +
                         to_string(std::get<1>(info.param));
      const auto& shape = std::get<2>(info.param);
      for (const auto& s : shape) {
        name += "X" + std::to_string(s);
      }
      name += "X" + test::typeName(std::get<3>(info.param)) +
              "X" + test::typeName(std::get<4>(info.param)) +
              "X" + test::caseName(std::get<5>(info.param)) +
              "X" + std::to_string(std::get<6>(info.param)) +
              "X" + std::to_string(std::get<7>(info.param) != 0.0f) +
              "X" + std::to_string(std::get<8>(info.param));
      return name;
    });

INSTANTIATE_TEST_SUITE_P(
    OperatorTest, FusedCastFloat8VectorwiseTestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(processing_methods),
        ::testing::ValuesIn(Activation_types),
        ::testing::ValuesIn(matrix_sizes),
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
        ::testing::ValuesIn(input_scenarios),
        ::testing::Values(true, false),
        ::testing::ValuesIn(amax_epsilons),
        ::testing::Values(true, false)),
    [](const testing::TestParamInfo<FusedCastFloat8VectorwiseTestSuite::ParamType>& info) {
      std::string name = to_string(std::get<0>(info.param)) + "X" +
                         to_string(std::get<1>(info.param));
      const auto& shape = std::get<2>(info.param);
      for (const auto& s : shape) {
        name += "X" + std::to_string(s);
      }
      name += "X" + test::typeName(std::get<3>(info.param)) +
              "X" + test::typeName(std::get<4>(info.param)) +
              "X" + test::caseName(std::get<5>(info.param)) +
              "X" + std::to_string(std::get<6>(info.param)) +
              "X" + std::to_string(std::get<7>(info.param) != 0.0f) +
              "X" + std::to_string(std::get<8>(info.param));
      return name;
    });
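One detail worth calling out in the comparisons above: the kernel's scale_inv tensors pad their inner dimension to a multiple of four entries, so compare_scaling_factors walks the test and reference buffers with different strides. A small sketch of the padded indexing, under the same pad-to-four assumption (helper names here are hypothetical, not from the diff):

#include <cstddef>

// Pad an inner dimension to a multiple of 4, as scale_align_stride does.
inline size_t padded_stride(size_t inner) { return ((inner + 3) / 4) * 4; }

// Row-major lookup of the (block_row, block_col) inverse scale in a
// buffer whose rows are padded_stride(blocks_per_row) floats wide.
inline float scale_at(const float* scales, size_t block_row, size_t block_col,
                      size_t blocks_per_row) {
  return scales[block_row * padded_stride(blocks_per_row) + block_col];
}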
tests/cpp/operator/test_normalization.cu

@@ -19,169 +19,16 @@
 #include <transformer_engine/normalization.h>
 #include <transformer_engine/transformer_engine.h>
 #include "../test_common.h"
+#include "test_normalization.h"

 using namespace transformer_engine;
 using namespace test;

 namespace {

-enum NormType { LayerNorm, RMSNorm };
-
-std::map<NormType, std::string> normToString = {
-    {NormType::LayerNorm, "LayerNorm"},
-    {NormType::RMSNorm, "RmsNorm"}};
-
-template <typename InputType>
-void compute_ref_stats(NormType norm_type, const InputType* data,
-                       float* mu, float* rsigma,
-                       const size_t N, const size_t H, const double epsilon) {
-  using compute_t = float;
-  compute_t current, m;
-  for (size_t i = 0; i < N; ++i) {
-    compute_t sum = 0;
-    for (size_t j = 0; j < H; ++j) {
-      sum += static_cast<compute_t>(data[i * H + j]);
-    }
-    if (norm_type == LayerNorm) {
-      mu[i] = sum / H;
-      m = mu[i];
-    } else {
-      m = 0;
-    }
-    compute_t sum_sq = 0;
-    for (size_t j = 0; j < H; ++j) {
-      current = static_cast<compute_t>(data[i * H + j]);
-      sum_sq += (current - m) * (current - m);
-    }
-#ifdef __HIP_PLATFORM_AMD__
-    rsigma[i] = 1.0 / sqrtf((sum_sq / H) + epsilon);
-#else
-    rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
-#endif
-  }
-}
-
-// For now, cudnn does static_cast<compute_t>(gamma + static_cast<input_t>(1.0))
-// This will be changed in the future release
-template <typename InputType>
-inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma,
-                          const bool use_cudnn) {
-  using compute_t = float;
-  if constexpr (std::is_same_v<InputType, fp8e5m2> ||
-                std::is_same_v<InputType, fp8e4m3>) {
-    compute_t g = static_cast<compute_t>(gamma);
-    if (zero_centered_gamma) {
-      g += static_cast<compute_t>(1.f);
-    }
-    return g;
-  } else {
-    if (use_cudnn) {
-      compute_t g = static_cast<compute_t>(0.f);
-#ifndef __HIP_PLATFORM_AMD__
-      InputType gi = gamma;
-      if (zero_centered_gamma) {
-        gi = gi + static_cast<InputType>(1.f);
-      }
-      g = static_cast<compute_t>(gi);
-#else
-      if (zero_centered_gamma) {
-        g += static_cast<compute_t>(1.f);
-      }
-#endif
-      return g;
-    } else {
-      compute_t g = static_cast<compute_t>(gamma);
-      if (zero_centered_gamma) {
-        g += static_cast<compute_t>(1.f);
-      }
-      return g;
-    }
-  }
-}
-
-template <typename InputType, typename OutputType>
-void compute_ref_output(NormType norm_type, const InputType* data,
-                        const InputType* gamma, const InputType* beta,
-                        OutputType* output, const float* mu, const float* rsigma,
-                        const size_t N, const size_t H, float* amax, float scale,
-                        const bool zero_centered_gamma, const bool use_cudnn) {
-  using compute_t = float;
-  compute_t current_max = -1e100;
-  for (size_t i = 0; i < N; ++i) {
-    for (size_t j = 0; j < H; ++j) {
-      compute_t current = static_cast<compute_t>(data[i * H + j]);
-      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
-      compute_t tmp;
-      if (norm_type == LayerNorm) {
-        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
-      } else {  // RMSNorm
-        tmp = current * rsigma[i] * g;
-      }
-      output[i * H + j] = static_cast<OutputType>(tmp * scale);
-      current_max = fmaxf(current_max, fabsf(tmp));
-    }
-  }
-  *amax = current_max;
-}
-
-template <typename InputType, typename OutputType>
-void compute_ref_backward(const NormType norm_type, const OutputType* output_grad,
-                          const InputType* data, const float* mu, const float* rsigma,
-                          const InputType* gamma, InputType* data_grad,
-                          InputType* gamma_grad, InputType* beta_grad,
-                          const size_t N, const size_t H,
-                          const bool zero_centered_gamma, const bool use_cudnn) {
-  using compute_t = float;
-  std::vector<compute_t> dgamma(H, 0.f);
-  std::vector<compute_t> dbeta(H, 0.f);
-
-  for (size_t i = 0; i < N; ++i) {
-    // Reductions
-    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
-    compute_t mdy = 0, mdyy = 0;
-    for (size_t j = 0; j < H; ++j) {
-      const compute_t x = static_cast<compute_t>(data[i * H + j]);
-      const compute_t y = (x - local_mu) * rsigma[i];
-      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
-      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
-      const compute_t dy = g * dz;
-      dgamma[j] += y * dz;
-      if (norm_type == LayerNorm) {
-        dbeta[j] += dz;
-        mdy += dy;
-      }
-      mdyy += dy * y;
-    }
-    mdy /= H;
-    mdyy /= H;
-
-    // Input grads
-    for (size_t j = 0; j < H; ++j) {
-      const compute_t x = static_cast<compute_t>(data[i * H + j]);
-      const compute_t y = (x - local_mu) * rsigma[i];
-      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn);
-      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
-      const compute_t dy = g * dz;
-      const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
-      data_grad[i * H + j] = static_cast<InputType>(dx);
-    }
-  }
-
-  // Weight grads
-  for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
-  if (norm_type == LayerNorm)
-    for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
-}
-
 template <typename InputType, typename OutputType>
 void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
-                 NormType norm_type, bool use_cudnn) {
+                 NormType norm_type, bool use_cudnn,
+                 const bool zero_centered_gamma_in_weight_dtype) {
   if (sizeof(InputType) < sizeof(OutputType)) {
     GTEST_SKIP() << "LN kernel does not support OutputType > InputType";
     return;

@@ -230,9 +77,22 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   cudaDeviceProp prop;
   cudaGetDeviceProperties(&prop, 0);

+  if ((!use_cudnn || !zero_centered_gamma) && zero_centered_gamma_in_weight_dtype) {
+    // Skip duplicate tests when zero_centered_gamma_in_weight_dtype is true and won't affect the implementation
+    GTEST_SKIP() << "Zero-centered gamma in weight dtype is only supported with cuDNN backend";
+  }
+
   if (use_cudnn) {
     nvte_enable_cudnn_norm_fwd(true);
     nvte_enable_cudnn_norm_bwd(true);
+    // Zero-centered gamma in weight dtype only supported by CuDNN backend currently
+    if (zero_centered_gamma_in_weight_dtype) {
+      nvte_enable_zero_centered_gamma_in_weight_dtype(true);
+    } else {
+      nvte_enable_zero_centered_gamma_in_weight_dtype(false);
+    }
   }

   // Forward kernel

@@ -280,6 +140,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   if (use_cudnn) {
     nvte_enable_cudnn_norm_fwd(false);
     nvte_enable_cudnn_norm_bwd(false);
+    // Zero-centered gamma in weight dtype only supported by CuDNN backend currently
+    if (zero_centered_gamma_in_weight_dtype) {
+      nvte_enable_zero_centered_gamma_in_weight_dtype(false);
+    }
   }

   // Reference implementations

@@ -300,14 +165,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
                      &ref_amax, ref_scale,
-                     zero_centered_gamma, use_cudnn);
+                     zero_centered_gamma, use_cudnn,
+                     zero_centered_gamma_in_weight_dtype);
   compute_ref_backward(norm_type, dz.rowwise_cpu_dptr<WeightType>(),
                        input.rowwise_cpu_dptr<InputType>(),
                        mu.rowwise_cpu_dptr<float>(),
                        rsigma.rowwise_cpu_dptr<float>(),
                        gamma.rowwise_cpu_dptr<WeightType>(),
                        ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(),
                        N, H, zero_centered_gamma,
-                       use_cudnn);
+                       use_cudnn, zero_centered_gamma_in_weight_dtype);

   cudaDeviceSynchronize();
   auto err = cudaGetLastError();

@@ -352,6 +219,7 @@ NormType,
 transformer_engine::DType,
 transformer_engine::DType,
 std::pair<size_t, size_t>,
-bool>> {};
+bool,
+bool>> {};

 TEST_P(NormTestSuite, TestNorm) {

@@ -364,10 +232,11 @@ TEST_P(NormTestSuite, TestNorm) {
   const DType output_type = std::get<3>(GetParam());
   const auto size = std::get<4>(GetParam());
   const bool zero_centered_gamma = std::get<5>(GetParam());
+  const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam());

   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
-      performTest<InputType, OutputType>(size.first, size.second,
-                                         zero_centered_gamma, norm_type, use_cudnn);
+      performTest<InputType, OutputType>(size.first, size.second,
+                                         zero_centered_gamma, norm_type, use_cudnn,
+                                         cudnn_zero_centered_gamm_in_weight_dtype);
     );
   );
 }

@@ -381,6 +250,7 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
     ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3),
     ::testing::ValuesIn(test_cases),
+    ::testing::Values(false, true),
     ::testing::Values(false, true)),
   [](const testing::TestParamInfo<NormTestSuite::ParamType>& info) {
     auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn";

@@ -391,6 +261,7 @@ INSTANTIATE_TEST_SUITE_P(
     test::typeName(std::get<3>(info.param)) + "X" +
     std::to_string(std::get<4>(info.param).first) + "X" +
     std::to_string(std::get<4>(info.param).second) + "X" +
-    std::to_string(std::get<5>(info.param));
+    std::to_string(std::get<5>(info.param)) + "X" +
+    std::to_string(std::get<6>(info.param));
   return name;
 });
tests/cpp/operator/test_normalization.h  (new file, 0 → 100644)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#pragma once

#include <cmath>
#include <cstring>
#include <memory>
#include <iomanip>
#include <iostream>
#include <random>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

namespace test {
namespace {

enum NormType { LayerNorm, RMSNorm };

std::map<NormType, std::string> normToString = {
    {NormType::LayerNorm, "LayerNorm"},
    {NormType::RMSNorm, "RmsNorm"}};

template <typename InputType>
void compute_ref_stats(NormType norm_type, const InputType* data,
                       float* mu, float* rsigma,
                       const size_t N, const size_t H, const double epsilon) {
  using compute_t = float;
  compute_t current, m;
  for (size_t i = 0; i < N; ++i) {
    compute_t sum = 0;
    for (size_t j = 0; j < H; ++j) {
      sum += static_cast<compute_t>(data[i * H + j]);
    }
    if (norm_type == LayerNorm) {
      mu[i] = sum / H;
      m = mu[i];
    } else {
      m = 0;
    }
    compute_t sum_sq = 0;
    for (size_t j = 0; j < H; ++j) {
      current = static_cast<compute_t>(data[i * H + j]);
      sum_sq += (current - m) * (current - m);
    }
#ifdef __HIP_PLATFORM_AMD__
    rsigma[i] = 1.0 / sqrtf((sum_sq / H) + epsilon);
#else
    rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
#endif
  }
}

template <typename InputType>
inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma,
                          const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  // Zero-centered gamma in weight dtype is only supported in CuDNN backend currently
  // Remove the use_cudnn check here when it is supported by both backends.
  const bool zero_centered_gamma_in_weight_dtype =
      use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;
  if constexpr (std::is_same_v<InputType, fp8e5m2> ||
                std::is_same_v<InputType, fp8e4m3>) {
    compute_t g = static_cast<compute_t>(gamma);
    if (zero_centered_gamma) {
      g += static_cast<compute_t>(1.f);
    }
    return g;
  } else {
    if (zero_centered_gamma_in_weight_dtype) {
      compute_t g = static_cast<compute_t>(0.f);
#ifndef __HIP_PLATFORM_AMD__
      InputType gi = gamma;
      if (zero_centered_gamma) {
        gi = gi + static_cast<InputType>(1.f);
      }
      g = static_cast<compute_t>(gi);
#else
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
#endif
      return g;
    } else {
      compute_t g = static_cast<compute_t>(gamma);
      if (zero_centered_gamma) {
        g += static_cast<compute_t>(1.f);
      }
      return g;
    }
  }
}

template <typename InputType, typename OutputType>
void compute_ref_output(NormType norm_type, const InputType* data,
                        const InputType* gamma, const InputType* beta,
                        OutputType* output, const float* mu, const float* rsigma,
                        const size_t N, const size_t H, float* amax, float scale,
                        const bool zero_centered_gamma, const bool use_cudnn,
                        const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  compute_t current_max = -1e100;
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      compute_t current = static_cast<compute_t>(data[i * H + j]);
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                  cudnn_zero_centered_gamma_in_weight_dtype);
      compute_t tmp;
      if (norm_type == LayerNorm) {
        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
      } else {  // RMSNorm
        tmp = current * rsigma[i] * g;
      }
      output[i * H + j] = static_cast<OutputType>(tmp * scale);
      current_max = fmaxf(current_max, fabsf(tmp));
    }
  }
  if (amax) {
    *amax = current_max;
  }
}

template <typename InputType, typename OutputType>
void compute_ref_backward(const NormType norm_type, const OutputType* output_grad,
                          const InputType* data, const float* mu, const float* rsigma,
                          const InputType* gamma, InputType* data_grad,
                          InputType* gamma_grad, InputType* beta_grad,
                          const size_t N, const size_t H,
                          const bool zero_centered_gamma, const bool use_cudnn,
                          const bool cudnn_zero_centered_gamma_in_weight_dtype) {
  using compute_t = float;
  std::vector<compute_t> dgamma(H, 0.f);
  std::vector<compute_t> dbeta(H, 0.f);

  for (size_t i = 0; i < N; ++i) {
    // Reductions
    auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.;
    compute_t mdy = 0, mdyy = 0;
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                  cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      dgamma[j] += y * dz;
      if (norm_type == LayerNorm) {
        dbeta[j] += dz;
        mdy += dy;
      }
      mdyy += dy * y;
    }
    mdy /= H;
    mdyy /= H;

    // Input grads
    for (size_t j = 0; j < H; ++j) {
      const compute_t x = static_cast<compute_t>(data[i * H + j]);
      const compute_t y = (x - local_mu) * rsigma[i];
      compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn,
                                  cudnn_zero_centered_gamma_in_weight_dtype);
      const compute_t dz = static_cast<compute_t>(output_grad[i * H + j]);
      const compute_t dy = g * dz;
      const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy);
      data_grad[i * H + j] = static_cast<InputType>(dx);
    }
  }

  // Weight grads
  for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast<InputType>(dgamma[j]);
  if (norm_type == LayerNorm)
    for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast<InputType>(dbeta[j]);
}

}  // namespace
}  // namespace test
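For reference, the gradients that compute_ref_backward implements can be stated compactly (a restatement of the code above, not text from the commit; overbars denote means over the hidden dimension H, and r_i = rsigma[i]):

y_{ij} = (x_{ij} - \mu_i)\, r_i, \qquad dy_{ij} = g_j\, dz_{ij}

dx_{ij} = r_i \left( dy_{ij} - \overline{dy\,y}_i\, y_{ij} - \overline{dy}_i \right)

d\gamma_j = \sum_i y_{ij}\, dz_{ij}, \qquad d\beta_j = \sum_i dz_{ij}

For RMSNorm the same formulas apply with \mu_i = 0 and the \overline{dy}_i term dropped, exactly as the mdy accumulator above is only updated in the LayerNorm branch.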
tests/cpp/operator/test_normalization_mxfp8.cu

@@ -19,6 +19,7 @@
 #include <transformer_engine/normalization.h>
 #include <transformer_engine/transformer_engine.h>
 #include "../test_common.h"
+#include "test_normalization.h"

 using namespace transformer_engine;
 using namespace test;

@@ -27,16 +28,6 @@ namespace {
 using fp8e8m0 = byte;

-enum NormType { LayerNorm, RMSNorm };
-
-std::map<NormType, std::string> normToString = {
-    {NormType::LayerNorm, "LayerNorm"},
-    {NormType::RMSNorm, "RMSNorm"}};
-
 template <typename InputType, typename ScaleType, typename OutputType>
 void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr,
                           OutputType* output_ptr, size_t rows, size_t cols,
                           size_t scaling_mode_x, size_t scaling_mode_y) {

@@ -110,69 +101,8 @@ void dequantize_2x(Tensor& input, Tensor& output, bool is_training)
               32, 1);
 }

-template <typename InputType>
-void compute_ref_stats(NormType norm_type, const InputType* data,
-                       float* mu, float* rsigma,
-                       const size_t N, const size_t H, const double epsilon) {
-  using compute_t = float;
-#pragma omp parallel for proc_bind(spread)
-  for (size_t i = 0; i < N; ++i) {
-    compute_t sum = 0;
-    for (size_t j = 0; j < H; ++j) {
-      sum += static_cast<compute_t>(data[i * H + j]);
-    }
-    compute_t m;
-    if (norm_type == LayerNorm) {
-      mu[i] = sum / H;
-      m = mu[i];
-    } else {
-      m = 0;
-    }
-    compute_t sum_sq = 0;
-    for (size_t j = 0; j < H; ++j) {
-      compute_t current = static_cast<compute_t>(data[i * H + j]);
-      sum_sq += (current - m) * (current - m);
-    }
-#ifdef __HIP_PLATFORM_AMD__
-    rsigma[i] = 1.0 / sqrtf((sum_sq / H) + epsilon);
-#else
-    rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
-#endif
-  }
-}
-
-template <typename InputType, typename OutputType>
-void compute_ref_output(NormType norm_type, const InputType* data,
-                        const InputType* gamma, const InputType* beta,
-                        const float* mu, const float* rsigma,
-                        const size_t N, const size_t H,
-                        OutputType* output, const bool zero_centered_gamma) {
-  using compute_t = float;
-#pragma omp parallel for proc_bind(spread)
-  for (size_t i = 0; i < N; ++i) {
-    for (size_t j = 0; j < H; ++j) {
-      compute_t current = static_cast<compute_t>(data[i * H + j]);
-      compute_t g = static_cast<compute_t>(gamma[j]);
-      if (zero_centered_gamma) {
-        g += 1.0;
-      }
-      compute_t tmp;
-      if (norm_type == LayerNorm) {
-        tmp = (current - mu[i]) * rsigma[i] * g + static_cast<compute_t>(beta[j]);
-      } else {  // RMSNorm
-        tmp = current * rsigma[i] * g;
-      }
-      output[i * H + j] = tmp;
-    }
-  }
-}
-
 template <typename InputType, typename OutputType>
 void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
-                 NormType norm_type, bool is_training) {
+                 NormType norm_type, bool is_training,
+                 const bool zero_centered_gamma_in_weight_dtype) {
   cudaDeviceProp prop;
   cudaGetDeviceProperties(&prop, 0);

@@ -199,6 +129,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   fillUniform(&gamma);
   fillUniform(&beta);

+  if (zero_centered_gamma_in_weight_dtype) {
+    nvte_enable_zero_centered_gamma_in_weight_dtype(true);
+  } else {
+    nvte_enable_zero_centered_gamma_in_weight_dtype(false);
+  }
+
   // Forward kernel
   float epsilon = 1e-5;
   if (norm_type == NormType::LayerNorm) {

@@ -224,6 +160,10 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
               0);
   }

+  if (zero_centered_gamma_in_weight_dtype) {
+    nvte_enable_zero_centered_gamma_in_weight_dtype(false);
+  }
+
   Tensor dequantized_output("dequantized_output", {N, H}, DType::kFloat32, true, true);
   dequantize_2x<OutputType, fp8e8m0>(z, dequantized_output, is_training);

@@ -250,11 +190,15 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   compute_ref_output(norm_type, input.rowwise_cpu_dptr<InputType>(),
                      gamma.rowwise_cpu_dptr<WeightType>(),
                      beta.rowwise_cpu_dptr<WeightType>(),
+                     ref_output.get(),
                      ref_mu_ptr, ref_rsigma_ptr, N, H,
-                     ref_output.get(), zero_centered_gamma);
+                     nullptr,  // amax
+                     1.f,      // scale
+                     zero_centered_gamma,
+                     true,     // CuDNN is the only MXFP8 backend currently
+                     zero_centered_gamma_in_weight_dtype);

   cudaDeviceSynchronize();
   auto err = cudaGetLastError();

@@ -302,7 +246,7 @@ class MxNormTestSuite : public ::testing::TestWithParam<std::tuple<NormType,
 transformer_engine::DType, transformer_engine::DType,
-std::pair<size_t, size_t>, bool, bool>> {};
+std::pair<size_t, size_t>, bool, bool, bool>> {};

 TEST_P(MxNormTestSuite, TestMxNorm) {
   using namespace transformer_engine;

@@ -314,10 +258,11 @@ TEST_P(MxNormTestSuite, TestMxNorm) {
   const auto size = std::get<3>(GetParam());
   const bool zero_centered_gamma = std::get<4>(GetParam());
   const bool is_training = std::get<5>(GetParam());
+  const bool zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam());

   TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
     TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType,
-      performTest<InputType, OutputType>(size.first, size.second,
-                                         zero_centered_gamma, norm_type, is_training);
+      performTest<InputType, OutputType>(size.first, size.second,
+                                         zero_centered_gamma, norm_type, is_training,
+                                         zero_centered_gamma_in_weight_dtype);
     );
   );
 }

@@ -331,6 +276,7 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
     ::testing::ValuesIn(test_cases),
     ::testing::Values(true, false),
+    ::testing::Values(true, false),
     ::testing::Values(true, false)),
   [](const testing::TestParamInfo<MxNormTestSuite::ParamType>& info) {
     std::string name = normToString.at(std::get<0>(info.param)) + "_" +

@@ -339,6 +285,7 @@ INSTANTIATE_TEST_SUITE_P(
       std::to_string(std::get<3>(info.param).first) + "X" +
       std::to_string(std::get<3>(info.param).second) + "X" +
       std::to_string(std::get<4>(info.param)) + "out" +
-      std::to_string(int(std::get<5>(info.param)) + 1) + "x";
+      std::to_string(int(std::get<5>(info.param)) + 1) + "x" +
+      std::to_string(std::get<6>(info.param));
   return name;
 });
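The MXFP8 tests above treat their scales as fp8e8m0 bytes. A decoding sketch; the biased power-of-two reading (value = 2^(byte - 127)) is my assumption about the E8M0 format, not something this diff states:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Interpret an E8M0 scale byte as a pure power of two (assumed bias of 127).
float decode_e8m0(uint8_t b) {
  return std::ldexp(1.0f, static_cast<int>(b) - 127);
}

int main() {
  std::printf("%g %g\n", decode_e8m0(127), decode_e8m0(130));  // 1 and 8
}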
tests/cpp/test_common.cu

@@ -10,6 +10,7 @@
 #include <algorithm>
 #include <memory>
 #include <random>
+#include <iostream>
 #include <cassert>
 #include <cmath>
 #include <string>

@@ -111,8 +112,8 @@ struct scale_inv_meta {
   size_t type_size;
 };

-NVTEShape convertShape(const std::vector<size_t>& shape) {
-  return {shape.data(), shape.size()};
+NVTEShape convertShape(const std::vector<size_t>& s) {
+  return nvte_make_shape(s.data(), s.size());
 }

 std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,

@@ -134,27 +135,19 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   scale_inv_meta ret_rowwise, ret_colwise;

   auto block_alignment = std::vector<size_t>{128ul, 4ul};
   {
     auto alignment = block_alignment[0];
     auto scale_dim_0 =
         DIVUP(DIVUP(first_dim, static_cast<size_t>(1)), alignment) * alignment;
     alignment = block_alignment[1];
     auto scale_dim_1 =
         DIVUP(DIVUP(last_dim, static_cast<size_t>(32)), alignment) * alignment;
     ret_rowwise.shape = {scale_dim_0, scale_dim_1};
   }
   {
     auto alignment = block_alignment[1];
     auto scale_dim_0 =
         DIVUP(DIVUP(first_dim, static_cast<size_t>(32)), alignment) * alignment;
     alignment = block_alignment[0];
     auto scale_dim_1 =
         DIVUP(DIVUP(last_dim, static_cast<size_t>(1)), alignment) * alignment;
     ret_colwise.shape = {scale_dim_0, scale_dim_1};
   }
   ret_rowwise.type = DType::kFloat8E8M0;

@@ -164,6 +157,58 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     return {ret_rowwise, ret_colwise};
   }
+  if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
+    std::vector<size_t> shape_vec;
+    for (size_t i = 0; i < shape.ndim; ++i) {
+      shape_vec.push_back(shape.data[i]);
+    }
+    size_t first_dim = first_dimension(shape_vec);
+    size_t last_dim = last_dimension(shape_vec);
+
+    scale_inv_meta ret_rowwise, ret_colwise;
+    {
+      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
+      auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast<size_t>(128)), 4) * 4;
+      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
+    }
+    {
+      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
+      auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast<size_t>(128)), 4) * 4;
+      ret_colwise.shape = {scale_dim_0, scale_dim_1};
+    }
+    ret_rowwise.type = DType::kFloat32;
+    ret_colwise.type = DType::kFloat32;
+    ret_rowwise.type_size = sizeof(float);
+    ret_colwise.type_size = sizeof(float);
+    return {ret_rowwise, ret_colwise};
+  }
+  if (scaling_mode == NVTE_BLOCK_SCALING_1D) {
+    std::vector<size_t> shape_vec;
+    for (size_t i = 0; i < shape.ndim; ++i) {
+      shape_vec.push_back(shape.data[i]);
+    }
+    size_t first_dim = first_dimension(shape_vec);
+    size_t last_dim = last_dimension(shape_vec);
+
+    scale_inv_meta ret_rowwise, ret_colwise;
+    {
+      auto scale_dim_0 = DIVUP(last_dim, static_cast<size_t>(128));
+      auto scale_dim_1 = DIVUP(first_dim, 4) * 4;
+      ret_rowwise.shape = {scale_dim_0, scale_dim_1};
+    }
+    {
+      auto scale_dim_0 = DIVUP(first_dim, static_cast<size_t>(128));
+      auto scale_dim_1 = DIVUP(last_dim, 4) * 4;
+      ret_colwise.shape = {scale_dim_0, scale_dim_1};
+    }
+    ret_rowwise.type = DType::kFloat32;
+    ret_colwise.type = DType::kFloat32;
+    ret_rowwise.type_size = sizeof(float);
+    ret_colwise.type_size = sizeof(float);
+    return {ret_rowwise, ret_colwise};
+  }
   NVTE_ERROR("Invalid scaling mode!");
 }
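The two new branches above compute the scale-inverse tensor shapes for 128x128 (2D) and 1x128 (1D) blocking. A compact sketch of the rowwise results for an M x N tensor (hypothetical helper names, mirroring the DIVUP and pad-to-four logic):

#include <cstddef>
#include <cstdio>
#include <utility>

static size_t divup(size_t a, size_t b) { return (a + b - 1) / b; }

// 2D blocks: one scale per 128x128 tile, inner dimension padded to 4.
std::pair<size_t, size_t> rowwise_2d(size_t M, size_t N) {
  return {divup(M, 128), divup(divup(N, 128), 4) * 4};
}

// 1D blocks: one scale per (row, 128-wide column block), rows padded to 4.
std::pair<size_t, size_t> rowwise_1d(size_t M, size_t N) {
  return {divup(N, 128), divup(M, 4) * 4};
}

int main() {
  auto [a, b] = rowwise_2d(256, 512);  // -> 2 x 4
  auto [c, d] = rowwise_1d(256, 512);  // -> 4 x 256
  std::printf("2D: %zu x %zu, 1D: %zu x %zu\n", a, b, c, d);
}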
...
...
@@ -195,10 +240,10 @@ Tensor::Tensor(const std::string& name,
std
::
vector
<
size_t
>
normalized_shape_v
=
{
product
(
shape
,
0
,
shape
.
ndim
-
1
),
shape
.
data
[
shape
.
ndim
-
1
]};
NVTEShape
normalized_shape
=
convertShape
(
normalized_shape_v
);
NVTEShape
columnwise_shape
{
nullptr
,
0
};
NVTEShape
columnwise_shape
=
{
};
std
::
vector
<
size_t
>
columnwise_shape_vec
;
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
||
scaling_mode
==
NVTE_BLOCK_SCALING_1D
||
scaling_mode
==
NVTE_BLOCK_SCALING_2D
)
{
// Transpose when tensor scaling
columnwise_shape_vec
.
emplace_back
(
shape
.
data
[
shape
.
ndim
-
1
]);
for
(
size_t
i
=
0
;
i
<
shape
.
ndim
-
1
;
++
i
)
{
...
...
@@ -212,8 +257,7 @@ Tensor::Tensor(const std::string& name,
}
if
(
columnwise
)
{
columnwise_shape
.
data
=
columnwise_shape_vec
.
data
();
columnwise_shape
.
ndim
=
columnwise_shape_vec
.
size
();
columnwise_shape
=
nvte_make_shape
(
columnwise_shape_vec
.
data
(),
columnwise_shape_vec
.
size
());
}
tensor_
=
TensorWrapper
(
scaling_mode
);
...
...
@@ -259,25 +303,27 @@ Tensor::Tensor(const std::string& name,
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
sizeof
(
float
),
0
);
}
}
else
{
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
normalized_shape
,
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
normalized_shape
,
tensor_
.
scaling_mode
());
auto
rowwise_scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
auto
columnwise_scale_size
=
product
(
colwise_scale_meta
.
shape
)
*
colwise_scale_meta
.
type_size
;
auto
scale_shape
=
rowwise_scale_meta
.
shape
;
auto
columnwise_scale_shape
=
colwise_scale_meta
.
shape
;
if
(
rowwise
)
{
cudaMalloc
((
void
**
)
&
rowwise_scale_inv
,
rowwise_scale_size
);
// NOLINT(*)
cudaMalloc
((
void
**
)
&
rowwise_scale_inv
,
rowwise_scale_size
);
// NOLINT(*)
cudaMemset
(
rowwise_scale_inv
,
0
,
rowwise_scale_size
);
rowwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
rowwise_scale_size
);
std
::
fill_n
(
rowwise_scale_inv_cpu_data_
.
get
(),
rowwise_scale_size
,
0
);
tensor_
.
set_rowwise_scale_inv
(
rowwise_scale_inv
,
DType
::
kFloat8E8M0
,
scale_shape
);
auto
scale_dtype
=
rowwise_scale_meta
.
type
;
tensor_
.
set_rowwise_scale_inv
(
rowwise_scale_inv
,
scale_dtype
,
scale_shape
);
}
if
(
columnwise
)
{
cudaMalloc
((
void
**
)
&
columnwise_scale_inv
,
columnwise_scale_size
);
// NOLINT(*)
cudaMemset
(
columnwise_scale_inv
,
0
,
columnwise_scale_size
);
columnwise_scale_inv_cpu_data_
=
std
::
make_unique
<
unsigned
char
[]
>
(
columnwise_scale_size
);
std
::
fill_n
(
columnwise_scale_inv_cpu_data_
.
get
(),
columnwise_scale_size
,
0
);
tensor_
.
set_columnwise_scale_inv
(
columnwise_scale_inv
,
DType
::
kFloat8E8M0
,
columnwise_scale_shape
);
auto
scale_dtype
=
colwise_scale_meta
.
type
;
tensor_
.
set_columnwise_scale_inv
(
columnwise_scale_inv
,
scale_dtype
,
columnwise_scale_shape
);
}
}
}
...
...
@@ -311,7 +357,8 @@ void Tensor::to_cpu() const {
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
}
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
if
(
rowwise_
)
{
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
cudaMemcpy
(
rowwise_scale_inv_cpu_data_
.
get
(),
...
...
@@ -349,7 +396,8 @@ void Tensor::from_cpu() const {
cudaMemcpy
(
tensor_
.
scale
(),
scale_cpu_data_
.
get
(),
sizeof
(
float
),
cudaMemcpyHostToDevice
);
}
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
auto
[
rowwise_scale_meta
,
colwise_scale_meta
]
=
get_scales
(
s
,
tensor_
.
scaling_mode
());
if
(
rowwise_
)
{
auto
scale_size
=
product
(
rowwise_scale_meta
.
shape
)
*
rowwise_scale_meta
.
type_size
;
cudaMemcpy
(
tensor_
.
get_rowwise_scale_inv
().
data_ptr
,
...
...
@@ -368,7 +416,7 @@ void Tensor::from_cpu() const {
void
Tensor
::
set_scale
(
float
scale
)
{
if
(
isFp8Type
(
dtype
()))
{
NVTE_CHECK
(
scale_cpu_data_
);
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
)
{
if
(
tensor_
.
scaling_mode
()
==
NVTE_DELAYED_TENSOR_SCALING
)
{
*
scale_cpu_data_
=
scale
;
from_cpu
();
}
...
...
@@ -383,27 +431,29 @@ void Tensor::set_scale_inv(float scale_inv) {
  if (columnwise_) {
    NVTE_CHECK(columnwise_scale_inv_cpu_data_);
  }
  auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
  if (rowwise_) {
    auto num_scales = product(rowwise_scale_meta.shape);
    if (num_scales == 1) {
      rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
    } else {
      std::uniform_int_distribution<uint8_t> dis(0, 127);
      auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
      for (size_t i = 0; i < num_scales; i++) {
        scale_inv_ptr[i] = dis(gen_);
      }
    }
  }
  if (columnwise_) {
    auto num_scales = product(colwise_scale_meta.shape);
    if (num_scales == 1) {
      columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
    } else {
      std::uniform_int_distribution<uint8_t> dis(0, 127);
      auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
      for (size_t i = 0; i < num_scales; i++) {
        scale_inv_ptr[i] = dis(gen_);
      }
    }
  }
...
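The random fill draws bytes in [0, 127] because an E8M0 scale is a bare biased exponent. A short sketch of the decoding that bound protects, assuming the usual bias of 127:

import numpy as np

rng = np.random.default_rng(0)
e = rng.integers(0, 128, size=16, dtype=np.uint8)   # same range as dis(0, 127)
scale_inv = np.float32(2.0) ** (e.astype(np.int32) - 127)
# Capping the exponent at 127 keeps every decoded scale-inverse in (0, 1],
# so randomly scaled test data dequantizes to finite, well-bounded values.
assert np.all((scale_inv > 0.0) & (scale_inv <= 1.0))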
@@ -413,23 +463,20 @@ void Tensor::set_scale_inv(float scale_inv) {
}

void Tensor::shareFP8Meta(const Tensor &other) {
  if (isFp8Type(dtype()) && isFp8Type(other.dtype())) {
    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
    auto my_rowwise_data = tensor_.get_rowwise_data();
    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
                                my_rowwise_data.shape);
    auto my_columnwise_data = tensor_.get_columnwise_data();
    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
                                   static_cast<DType>(my_columnwise_data.dtype),
                                   my_columnwise_data.shape);
    auto other_amax = other.tensor_.get_amax();
    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype), other_amax.shape);
    auto other_scale = other.tensor_.get_scale();
    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
                         other_scale.shape);
    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
...
@@ -460,9 +507,7 @@ std::string to_string(const std::vector<T> &v) {
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
  std::vector<size_t> ret;
  size_t current_i = i;
  for (size_t current = shape.ndim - 1; current > 0; --current) {
    ret.push_back(current_i % shape.data[current]);
    current_i /= shape.data[current];
  }
...
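`unravel` is the row-major inverse of flattening, i.e. the same arithmetic as NumPy's `unravel_index`. An equivalent Python sketch:

def unravel(i, shape):
    # Peel off the innermost dimensions first, as the C++ loop does.
    coords = []
    for dim in reversed(shape[1:]):
        coords.append(i % dim)
        i //= dim
    coords.append(i)  # whatever remains indexes the outermost dimension
    return list(reversed(coords))

assert unravel(7, (2, 3, 4)) == [0, 1, 3]  # 7 == 0*12 + 1*4 + 3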
@@ -750,7 +795,7 @@ void fillCase_special(Tensor *t) {
    });
  } else {
    double minAbs = -2.0;
    double maxAbs = 1.0;
    if constexpr (Case != InputsFillCase::uniform) {
      minAbs = Quantized_Limits<InputEncoding>::ranges[Case];
      maxAbs = Quantized_Limits<InputEncoding>::ranges[Case + 1];
...
@@ -809,14 +854,13 @@ void setRandomScaleInv(Tensor *t) {
}

bool isFp8Type(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0;
}

int32_t getDeviceComputeCapability() {
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, 0);
  return 10 * deviceProp.major + deviceProp.minor;
}

size_t first_dimension(const std::vector<size_t> &shape) {
...
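`getDeviceComputeCapability` packs major/minor into `10 * major + minor`, matching the `hopperComputeCapability = 90` and `blackwellComputeCapability = 100` constants declared in the header below. A hypothetical Python-side equivalent for gating tests (the helper name and use of torch are illustrative, not from this commit):

import pytest
import torch  # assumption: available in the test environment

BLACKWELL_CC = 100  # 10 * major + minor

def device_compute_capability() -> int:
    major, minor = torch.cuda.get_device_capability()
    return 10 * major + minor

requires_blackwell = pytest.mark.skipif(
    not torch.cuda.is_available() or device_compute_capability() < BLACKWELL_CC,
    reason="requires compute capability >= 10.0",
)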
tests/cpp/test_common.h
...
@@ -121,7 +121,7 @@ class Tensor {
         const bool rowwise = true, const bool columnwise = false,
         const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING)
-   : Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}
+   : Tensor(name, nvte_make_shape(shape.data(), shape.size()), type, rowwise, columnwise, mode) {}

  Tensor() {}
...
@@ -148,25 +148,19 @@ class Tensor {
    if (scale_inv != nullptr) {
      cudaFree(scale_inv);
    }
    if (columnwise_data_ptr != nullptr) {
      cudaFree(columnwise_data_ptr);
    }
    if (columnwise_scale_inv != nullptr) {
      cudaFree(columnwise_scale_inv);
    }
  }

  NVTETensor data() const noexcept { return tensor_.data(); }
  NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; }
  NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; }
  NVTEShape rowwise_scale_inv_shape() const {
    NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
...
@@ -233,6 +227,8 @@ class Tensor {
  T *rowwise_cpu_scale_inv_ptr() {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+   } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
+              tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
+     NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
...
@@ -244,6 +240,8 @@ class Tensor {
  T *columnwise_cpu_scale_inv_ptr() {
    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
+   } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D ||
+              tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
+     NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
    } else {
      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
    }
...
@@ -475,6 +473,7 @@ extern std::vector<DType> all_fp_types;
bool isFp8Type(DType type);

int32_t getDeviceComputeCapability();
constexpr int32_t hopperComputeCapability = 90;
constexpr int32_t blackwellComputeCapability = 100;

}  // namespace test
...
tests/jax/pytest.ini
...
@@ -25,3 +25,5 @@ filterwarnings=
    ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
    ignore:The host_callback APIs are deprecated .*:DeprecationWarning
    ignore:Scan loop is disabled for fused ring attention.*:UserWarning
+   ignore:jax.extend.ffi.register_ffi_target is deprecated
+   ignore:jax.extend.ffi.ffi_lowering is deprecated
tests/jax/test_custom_call_compute.py
...
@@ -29,7 +29,7 @@ from transformer_engine.jax.quantize import (
    ScaledTensor,
    ScalingMode,
    QuantizerFactory,
-   QuantizeAxis,
+   QuantizeLayout,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
...
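The change here is the rename of the quantizer-layout enum (and, below, the drop of the `NVTE_` prefix on `ScalingMode` members). A minimal usage sketch with the new spellings (shapes are illustrative):

import jax.numpy as jnp
from transformer_engine.jax.quantize import QuantizerFactory, QuantizeLayout, ScalingMode

quantizer = QuantizerFactory.create(
    scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
    q_dtype=jnp.float8_e4m3fn,
    q_layout=QuantizeLayout.ROWWISE_COLWISE,  # produce rowwise and colwise data
)
scaled = quantizer.quantize(jnp.ones((128, 128), jnp.bfloat16))
dequantized = scaled.dequantize()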
@@ -48,21 +48,21 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]

is_fp8_supported, reason = helper.is_fp8_available()
-is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

supported_scaling_modes = []
""" Find supported scaling modes"""
if is_fp8_supported:
-   supported_scaling_modes.append(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
+   supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported:
-   supported_scaling_modes.append(ScalingMode.NVTE_MXFP8_1D_SCALING)
+   supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)


def is_shape_supported_by_mxfp8(input_shape):
    try:
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
-       ScalingMode.NVTE_MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
+       ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except:
        # get_scale_shapes will raise an exception if the shape is not supported
...
@@ -82,8 +82,9 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor):
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    if isinstance(a, ScaledTensor1x):
-       if a.layout == "T":
-           b_transpose = jnp.transpose(b, (-1, *range(b.ndim - 1)))
+       if a.data_layout == "T":
+           flatten_axis = a.data.ndim - a.flatten_axis
+           b_transpose = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
        assert_allclose(a.dequantize(), b_transpose, dtype=a.data.dtype)
    else:
        assert_allclose(a.dequantize(), b, dtype=a.data.dtype)
...
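The new transpose rotates the whole flattened trailing group to the front instead of a single axis. The same axis arithmetic in isolation, with illustrative shapes:

import jax.numpy as jnp

b = jnp.zeros((2, 3, 4, 5))
flatten_axis = 2  # treat dims [2:] as the "columns" of the flattened 2D view

# Move the trailing group in front of the leading group:
# (2, 3, 4, 5) -> (4, 5, 2, 3), i.e. the transpose of the (2*3, 4*5) matrix view.
b_t = jnp.transpose(b, (*range(flatten_axis, b.ndim), *range(flatten_axis)))
assert b_t.shape == (4, 5, 2, 3)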
@@ -141,7 +142,8 @@ class TestActivation:
    def test_act_grad(self, shape, activation_type):
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
-       x = jnp.repeat(x, len(activation_type), axis=-1)
+       x = jnp.expand_dims(x, axis=-2)
+       x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(value_and_grad(self.primitive_func, (0,)), static_argnums=(1,)
...
@@ -159,7 +161,8 @@ class TestActivation:
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    def test_act_grad_with_delayed_scaling_fp8(self, random_inputs, activation_type, output_type):
        x = random_inputs
-       x = jnp.repeat(x, len(activation_type), axis=-1)
+       x = jnp.expand_dims(x, axis=-2)
+       x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
...
@@ -167,9 +170,9 @@ class TestActivation:
        )

        quantizer = QuantizerFactory.create(
-           scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
+           scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
-           q_axis=QuantizeAxis.ROWWISE,
+           q_layout=QuantizeLayout.ROWWISE,
        )
        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
...
@@ -182,19 +185,22 @@ class TestActivation:
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_act_forward_with_delayed_scaling_fp8(
-       self, random_inputs, activation_type, output_type, q_axis
+       self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
-       x = jnp.repeat(x, len(activation_type), axis=-1)
+       x = jnp.expand_dims(x, axis=-2)
+       x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
-           scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
+           scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            q_dtype=output_type,
-           q_axis=q_axis,
+           q_layout=q_layout,
        )

        te_output = tex.act_lu(x, activation_type, te_quantizer)
...
@@ -203,19 +209,21 @@ class TestActivation:
        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
-   @pytest_parametrize_wrapper("shape", [(128, 128)])
+   @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_act_forward_with_block_scaling_fp8(
-       self, random_inputs, activation_type, output_type, q_axis
+       self, random_inputs, activation_type, output_type, q_layout
    ):
        x = random_inputs
        x = jnp.repeat(x, len(activation_type), axis=-1)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
-           scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING, q_dtype=output_type, q_axis=q_axis
+           scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )
        output = tex.act_lu(x, activation_type, quantizer)
...
@@ -324,9 +332,11 @@ class TestNorm:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_norm_grad_with_delayed_scaling_fp8(
-       self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
+       self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        """
        Test transformer_engine.jax.layernorm.layernorm
...
@@ -335,7 +345,9 @@ class TestNorm:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        quantizer = QuantizerFactory.create(
-           scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_axis=q_axis
+           scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING, q_dtype=out_dtype, q_layout=q_layout,
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
...
@@ -351,7 +363,7 @@ class TestNorm:
        inp_dtype,
        out_dtype,
        scaling_mode,
-       q_axis,
+       q_layout,
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 3)
...
@@ -363,7 +375,7 @@ class TestNorm:
        gamma = jnp.asarray(gamma, inp_dtype)

        quantizer, ref_quantizer = QuantizerFactory.create(
-           n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_axis=q_axis
+           n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )

        if norm_type == "layernorm":
            beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
...
@@ -391,9 +403,11 @@ class TestNorm:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_norm_forward_with_delayed_scaling_fp8(
-       self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_axis
+       self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype, q_layout
    ):
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
...
@@ -406,8 +420,8 @@ class TestNorm:
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
-           scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
-           q_axis=q_axis,
+           scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
+           q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...
@@ -423,8 +437,8 @@ class TestNorm:
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
-           scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
-           q_axis=QuantizeAxis.ROWWISE_COLWISE,
+           scaling_mode=ScalingMode.MXFP8_1D_SCALING,
+           q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
...
@@ -434,14 +448,14 @@ QUANTIZE_OUTPUT_DTYPES = {
}

ALL_QUANTIZE_TEST_SHAPES = [
-   (128, 128),
-   (4, 256, 512),
+   (32, 64),
+   (2, 64, 32),
]

QUANTIZE_TEST_SHAPES = {
    "L0": [
-       (256, 128),
-       (64, 16, 2, 256),
+       (32, 256, 128),
+       (64, 32, 32, 256),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES,
}
...
@@ -457,48 +471,52 @@ QUANTIZATION_INPUT_DTYPE = {
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape", ALL_QUANTIZE_TEST_SHAPES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
+@pytest_parametrize_wrapper("flatten_axis", [-1, -2])
@pytest_parametrize_wrapper(
-   "q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE]
+   "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestQuantize:
    """
    Purely quantization related tests that will always test on a wider set of types and shapes
    """

-   def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
+   def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        key = jax.random.PRNGKey(0)

        # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
-           scaling_mode=scaling_mode, q_dtype=q_dtype, q_axis=q_axis,
+           scaling_mode=scaling_mode, q_dtype=q_dtype, q_layout=q_layout,
        )

+       # Adding dimension to test if padding is done correctly when flatten 3D to 2D
+       if flatten_axis == -2:
+           input_shape = input_shape[:-1] + (2,) + input_shape[-1:]

-       n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
+       n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            x = jax.random.uniform(key, input_shape, in_dtype)
-           scaled_tensor = quantizer.quantize(x)
+           scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
            assert_dequantized_scaled_tensor(scaled_tensor, x)

-   def test_quantize_bitwise(self, in_dtype, input_shape, q_dtype, scaling_mode, q_axis):
-       if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(input_shape):
-           pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
+   def test_quantize_bitwise(
+       self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
+   ):
        key = jax.random.PRNGKey(0)
+       if flatten_axis == -2:
+           input_shape = input_shape[:-1] + (2,) + input_shape[-1:]
        input = jax.random.uniform(key, input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
-           n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_axis=q_axis
+           n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

-       jax_output = _jax_quantize(input, quantizer=jax_quantizer)
-       te_output = tex.quantize(input, quantizer=te_quantizer)
-       assert_bitwise_scaled_tensors(jax_output, te_output)
+       jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
+       te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
+       assert_bitwise_scaled_tensors(te_output, jax_output)

@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
...
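The `flatten_axis == -2` branch inserts an extra dimension so the quantizer must fold two trailing dims into the columns of its 2D view. The shape surgery in isolation:

input_shape = (32, 256, 128)
flatten_axis = -2

# Insert an extra dim before the last one.
padded = input_shape[:-1] + (2,) + input_shape[-1:]
assert padded == (32, 256, 2, 128)

# Effective 2D view the quantizer flattens to: leading dims x trailing group.
rows, cols = 32 * 256, 2 * 128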
@@ -508,10 +526,14 @@ class TestFusedQuantize:
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    @pytest_parametrize_wrapper("input_shape", QUANTIZE_TEST_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.ROWWISE, QuantizeAxis.ROWWISE_COLWISE])
-   def test_quantize_dbias(self, in_dtype, input_shape, out_dtype, scaling_mode, q_axis):
-       transpose_axis = -1
-       if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
+   @pytest_parametrize_wrapper("flatten_axis", [-1, -2])
+   def test_quantize_dbias(
+       self, in_dtype, input_shape, out_dtype, scaling_mode, q_layout, flatten_axis
+   ):
+       if scaling_mode == ScalingMode.MXFP8_1D_SCALING and not is_shape_supported_by_mxfp8(
            input_shape
        ):
            pytest.skip(f"Input shape {input_shape} is not supported by MXFP8")
...
@@ -520,35 +542,37 @@ class TestFusedQuantize:
        input = jax.random.uniform(key, input_shape, in_dtype)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
-           n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
+           n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

-       te_output, te_dbias = jit(lambda input: tex.quantize_dbias(input, quantizer=te_quantizer))(input)
+       te_output, te_dbias = jit(
+           lambda input: tex.quantize_dbias(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
+       )(input)
        jax_output, jax_dbias = jit(
-           lambda input: _jax_quantize_dbias(input, quantizer=jax_quantizer,
+           lambda input: _jax_quantize_dbias(
+               input, quantizer=jax_quantizer, flatten_axis=flatten_axis
            )
        )(input)

-       assert_bitwise_scaled_tensors(jax_output, te_output)
-       assert_allclose(jax_dbias, te_dbias)
+       assert_bitwise_scaled_tensors(te_output, jax_output)
+       assert_allclose(te_dbias, jax_dbias)

    def _test_quantize_dact_dbias(
-       self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_axis
+       self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout
    ):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1)
-       x = jnp.repeat(x, len(activation_type), axis=-1)
+       x = jnp.expand_dims(x, axis=-2)
+       x = jnp.repeat(x, len(activation_type), axis=-2)
        dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

        jax_quantizer, te_quantizer = QuantizerFactory.create(
-           n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_axis=q_axis
+           n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        is_casted_output = te_quantizer is not None
...
@@ -573,12 +597,12 @@ class TestFusedQuantize:
        )(dz, x)

        if is_casted_output:
-           assert_bitwise_scaled_tensors(jax_output, te_output)
+           assert_bitwise_scaled_tensors(te_output, jax_output)
        else:
-           assert_allclose(jax_output, te_output)
+           assert_allclose(te_output, jax_output)

        if is_dbias:
-           assert_allclose(jax_dbias, te_dbias)
+           assert_allclose(te_dbias, jax_dbias)

    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
...
@@ -594,10 +618,10 @@ class TestFusedQuantize:
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=in_dtype,
-           scaling_mode=ScalingMode.NVTE_NO_SCALING,
+           scaling_mode=ScalingMode.NO_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
-           q_axis=QuantizeAxis.ROWWISE,
+           q_layout=QuantizeLayout.ROWWISE,
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
...
@@ -605,18 +629,20 @@ class TestFusedQuantize:
    @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_quantize_dact_dbias_delayed_scaling(
-       self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
+       self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        self._test_quantize_dact_dbias(
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
-           scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
+           scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
-           q_axis=q_axis,
+           q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
...
@@ -626,9 +652,11 @@ class TestFusedQuantize:
    )
    @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
    @pytest_parametrize_wrapper("is_dbias", [True, False])
-   @pytest_parametrize_wrapper("q_axis", [QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE])
+   @pytest_parametrize_wrapper(
+       "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
+   )
    def test_quantize_dact_dbias_mxfp8_scaling(
-       self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_axis
+       self, in_dtype, input_shape, out_dtype, activation_type, is_dbias, q_layout
    ):
        if reduce(operator.mul, input_shape[:-1]) % 128 != 0 or input_shape[-1] % 128 != 0:
            # TODO(Jeremy): Remove this if pulling in newer TE branch supports non-full-tile shapes.
...
@@ -642,78 +670,78 @@ class TestFusedQuantize:
            in_dtype=in_dtype,
            input_shape=input_shape,
            out_dtype=out_dtype,
-           scaling_mode=ScalingMode.NVTE_MXFP8_1D_SCALING,
+           scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            activation_type=activation_type,
            is_dbias=is_dbias,
-           q_axis=q_axis,
+           q_layout=q_layout,
        )


class TestDense:
-   def _ref_gemm_with_jnp_dot(self, a, b, layout):
-       if layout[0] == "T":
+   def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
+       if data_layout[0] == "T":
            a = jnp.swapaxes(a, -1, -2)
-       if layout[1] == "T":
+       if data_layout[1] == "T":
            b = jnp.swapaxes(b, -1, -2)
        return jnp.dot(a, b)

-   def _generate_gemm_input(self, m, n, k, layout):
+   def _generate_gemm_input(self, m, n, k, data_layout):
        key = jax.random.PRNGKey(0)
        subkeys = jax.random.split(key, 2)
        x = jax.random.uniform(
            subkeys[0],
-           (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
+           (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(k)
        w = jax.random.uniform(
            subkeys[1],
-           (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
+           (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
            dtype=jnp.bfloat16,
        ) / jnp.sqrt(n)
-       lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
-       rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
+       lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+       rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
        contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
        return (x, w, contracting_dims)

-   @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
-   @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
-   def test_gemm_bf16(self, m, n, k, layout):
-       x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
+   @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
+   @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
+   def test_gemm_bf16(self, m, n, k, data_layout):
+       x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        primitive_out = tex.gemm(x, w, contracting_dims)
-       ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
+       ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-   @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
+   @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
-   @pytest_parametrize_wrapper("layout", ["TN", "NT", "NN", "TT"])
-   def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, layout):
-       x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
+   @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"])
+   def test_gemm_fp8(self, m, n, k, q_dtype, scaling_mode, data_layout):
+       x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
        quantizer_set = QuantizerFactory.create_set(
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=False
        )
        primitive_out = tex.gemm(x, w, contracting_dims=contracting_dims, quantizer_set=quantizer_set)
-       ref_out = self._ref_gemm_with_jnp_dot(x, w, layout)
+       ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=q_dtype)

-   @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
+   @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    def test_dense_grad_bf16(self, m, n, k):
-       layout = "NN"
-       x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
+       data_layout = "NN"
+       x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        def primitive_func(x, w, contracting_dims):
            primitive_out = dense(x, w, contracting_dims=contracting_dims)
            return jnp.mean(primitive_out)

-       def ref_func(x, w, layout):
-           return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, layout))
+       def ref_func(x, w, data_layout):
+           return jnp.mean(self._ref_gemm_with_jnp_dot(x, w, data_layout))

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1))
...
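For reference, the "N"/"T" letters above select which operand axis is contracted. The same mapping expressed directly with `jax.lax.dot_general`, using illustrative shapes:

import jax
import jax.numpy as jnp

m, n, k = 64, 32, 64
x_t = jnp.zeros((k, m))  # lhs stored transposed: data_layout[0] == "T"
w = jnp.zeros((k, n))    # rhs stored normally:   data_layout[1] == "N"

# "T" lhs contracts over axis 0; "N" rhs also contracts over axis 0.
out = jax.lax.dot_general(x_t, w, dimension_numbers=(((0,), (0,)), ((), ())))
assert out.shape == (m, n)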
@@ -722,19 +750,19 @@ class TestDense:
        primitive_out, (primitive_x_grad, primitive_w_grad) = value_n_grad_primitive_func(
            x, w, contracting_dims
        )
-       ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, layout)
+       ref_out, (ref_x_grad, ref_w_grad) = value_n_grad_ref_func(x, w, data_layout)

        assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
        assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
        assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-   @pytest_parametrize_wrapper("m,n,k", [(512, 128, 256)])
+   @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
    @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
    def test_dense_grad_fp8(self, m, n, k, q_dtype, scaling_mode):
-       layout = "NN"
-       x, w, contracting_dims = self._generate_gemm_input(m, n, k, layout)
+       data_layout = "NN"
+       x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)

        key = jax.random.PRNGKey(1)
        bias = jax.random.uniform(key, n, dtype=jnp.bfloat16)
...
@@ -745,9 +773,9 @@ class TestDense:
            )
            return jnp.mean(primitive_out)

-       def ref_func(x, w, bias, layout):
+       def ref_func(x, w, bias, data_layout):
            return jnp.mean(
-               self._ref_gemm_with_jnp_dot(x, w, layout) + jnp.expand_dims(bias, axis=0)
+               self._ref_gemm_with_jnp_dot(x, w, data_layout) + jnp.expand_dims(bias, axis=0)
            )

        value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2))
...
@@ -757,13 +785,15 @@ class TestDense:
            scaling_mode=scaling_mode, fwd_dtype=q_dtype, bwd_dtype=q_dtype, is_2x2x=True
        )

-       n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
+       n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = (
                value_n_grad_primitive_func(x, w, bias, contracting_dims, quantizer_set)
            )
-           ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(x, w, bias, layout)
+           ref_out, (ref_x_grad, ref_w_grad, ref_bias_grad) = value_n_grad_ref_func(
+               x, w, bias, data_layout
+           )

            assert_allclose(primitive_out, ref_out, dtype=q_dtype)
            assert_allclose(primitive_x_grad, ref_x_grad, dtype=q_dtype)
...
@@ -791,7 +821,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
class TestFusedDense:
    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-   @pytest.mark.parametrize("m,n,k", [(512, 128, 128)])
+   @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
    @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"])
...
@@ -800,7 +830,7 @@ class TestFusedDense:
        Test layernorm_dense VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
-       if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+       if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
...
@@ -856,7 +886,7 @@ class TestFusedDense:
            x, w, gamma, beta
        )

-       n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
+       n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
...
@@ -873,7 +903,7 @@ class TestFusedDense:
        assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-   @pytest.mark.parametrize("m,n,k", [(512, 128, 256)])
+   @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
    @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
...
@@ -886,7 +916,7 @@ class TestFusedDense:
        Test layernorm_mlp VJP Rule
        """
        # No Norm FWD E5M2 in TE backend
-       if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+       if q_dtype == jnp.float8_e5m2 and scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
            pytest.skip("E5M2 is not supported in normalization with TE Backend!")

        # zero_centered_gamma is already tested in TestNorm
...
@@ -898,13 +928,13 @@ class TestFusedDense:
        x = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16)
        kernel_1 = jax.random.normal(
-           subkeys[1], (k, len(activation_type) * n), jnp.bfloat16
+           subkeys[1], (k, len(activation_type), n), jnp.bfloat16
        ) / jnp.sqrt(k)
        kernel_2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) / jnp.sqrt(n)
        gamma = jax.random.normal(subkeys[5], (k,), jnp.bfloat16)
        beta = None  # was tested in TestNorm
        if use_bias:
-           bias_1 = jax.random.normal(subkeys[3], (len(activation_type) * n), jnp.bfloat16)
+           bias_1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16)
            bias_2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16)
        else:
            bias_1 = None
...
@@ -963,7 +993,7 @@ class TestFusedDense:
        value_n_grad_prim_func = value_and_grad(prim_func, range(6))
        value_n_grad_ref_func = value_and_grad(ref_func, range(6))

-       n_iterations = 3 if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING else 1
+       n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
        for _ in range(n_iterations):
            prim_out, (
                prim_x_grad,
...
@@ -1039,19 +1069,19 @@ class TestGroupedDense:
        subkeys = jax.random.split(key, len(shape_list) * 2)

        lhs_list, rhs_list, contracting_dims_list = [], [], []
-       for i, ((m, n, k), layout) in enumerate(zip(shape_list, layout_list)):
+       for i, ((m, n, k), data_layout) in enumerate(zip(shape_list, layout_list)):
            lhs = jax.random.uniform(
                subkeys[2 * i],
-               (m if layout[0] == "N" else k, k if layout[0] == "N" else m),
+               (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
                dtype=dtype,
            )
            rhs = jax.random.uniform(
                subkeys[2 * i + 1],
-               (k if layout[1] == "N" else n, n if layout[1] == "N" else k),
+               (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
                dtype=dtype,
            )
-           lhs_contracting_dim = (1,) if layout[0] == "N" else (0,)
-           rhs_contracting_dim = (0,) if layout[1] == "N" else (1,)
+           lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
+           rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
            contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
            lhs_list.append(lhs)
...
tests/jax/test_distributed_fused_attn.py
...
@@ -48,31 +48,7 @@ class TestDistributedSelfAttn:
        # for loss and dbias
        return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0)

-   @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-   @pytest.mark.parametrize(
-       "data_shape",
-       [
-           pytest.param((32, 512, 12, 64), id="32-512-12-64"),
-           pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
-       ],
-   )
-   @pytest.mark.parametrize(
-       "attn_bias_type, bias_shape",
-       [
-           pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
-           pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
-           pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
-       ],
-   )
-   @pytest.mark.parametrize(
-       "attn_mask_type",
-       [
-           pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
-           pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
-       ],
-   )
-   @pytest.mark.parametrize("dtype", DTYPES)
-   def test_self_attn(
+   def impl_test_self_attn(
        self,
        device_count,
        mesh_shape,
...
@@ -83,7 +59,9 @@ class TestDistributedSelfAttn:
        bias_shape,
        attn_mask_type,
        dtype,
+       use_shardy,
    ):
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        dropout_prob = 0.0
        is_training = True
...
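Several tests below flip the partitioner per-case with the same one-liner. A hedged sketch of restoring the flag afterwards (the try/finally wrapper is illustrative, not from this commit):

import jax

def with_shardy(use_shardy: bool, fn, *args, **kwargs):
    # Run fn with jax_use_shardy_partitioner set, then restore the old value.
    old = jax.config.jax_use_shardy_partitioner
    jax.config.update("jax_use_shardy_partitioner", use_shardy)
    try:
        return fn(*args, **kwargs)
    finally:
        jax.config.update("jax_use_shardy_partitioner", old)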
@@ -137,6 +115,80 @@ class TestDistributedSelfAttn:
        )
        runner.test_backward()

+   @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+   @pytest.mark.parametrize(
+       "data_shape",
+       [
+           pytest.param((32, 512, 12, 64), id="32-512-12-64"),
+           pytest.param((32, 1024, 16, 128), id="32-1024-16-128"),
+       ],
+   )
+   @pytest.mark.parametrize(
+       "attn_bias_type, bias_shape",
+       [
+           pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+           pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
+           pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"),
+       ],
+   )
+   @pytest.mark.parametrize(
+       "attn_mask_type",
+       [
+           pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"),
+           pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"),
+       ],
+   )
+   @pytest.mark.parametrize("dtype", DTYPES)
+   def test_self_attn(
+       self,
+       device_count,
+       mesh_shape,
+       mesh_axes,
+       mesh_resource,
+       data_shape,
+       attn_bias_type,
+       bias_shape,
+       attn_mask_type,
+       dtype,
+   ):
+       self.impl_test_self_attn(
+           device_count,
+           mesh_shape,
+           mesh_axes,
+           mesh_resource,
+           data_shape,
+           attn_bias_type,
+           bias_shape,
+           attn_mask_type,
+           dtype,
+           use_shardy=False,
+       )

+   @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+   @pytest.mark.parametrize(
+       "attn_bias_type, bias_shape",
+       [
+           pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+           pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"),
+       ],
+   )
+   def test_self_attn_shardy(
+       self, device_count, mesh_shape, mesh_axes, mesh_resource, attn_bias_type, bias_shape
+   ):
+       data_shape = (32, 512, 12, 64)
+       self.impl_test_self_attn(
+           device_count,
+           mesh_shape,
+           mesh_axes,
+           mesh_resource,
+           data_shape,
+           attn_bias_type,
+           bias_shape,
+           AttnMaskType.PADDING_MASK,
+           jnp.bfloat16,
+           use_shardy=True,
+       )


class TestDistributedCrossAttn:
...
@@ -203,37 +255,23 @@ class TestDistributedCrossAttn:
        runner.test_backward()

-@pytest.mark.parametrize(
-   "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
-)
-@pytest.mark.parametrize(
-   "data_shape",
-   [
-       # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
-       pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
-       pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
-   ],
-)
-@pytest.mark.parametrize("kv_groups", [1, 8])
-@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
-@pytest.mark.parametrize(
-   "qkv_layout, attn_mask_type",
-   [
-       pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
-       pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
-       pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
-       pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
-       pytest.param(
-           QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"
-       ),
-   ],
-)
-@pytest.mark.parametrize(
-   "load_balanced",
-   [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
-)
+DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS = [
+   pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"),
+   pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"),
+   pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"),
+   pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"),
+   pytest.param(QKVLayout.THD_THD_THD, AttnMaskType.PADDING_CAUSAL_MASK, id="THD_SEPARATE-PADDING_CAUSAL"),
+]

+DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES = [
+   # Sequence lengths will be scaled by CP so that we don't run with tiny sizes.
+   pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"),
+   pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"),
+]


class TestDistributedContextParallelSelfAttn:
    def impl_test_context_parallel_attn(
...
@@ -249,7 +287,23 @@ class TestDistributedContextParallelSelfAttn:
        qkv_layout,
        load_balanced,
        cp_strategy,
+       use_shardy,
+       use_scan_ring=False,
    ):
+       if qkv_layout.is_thd():
+           if cp_strategy == CPStrategy.ALL_GATHER:
+               pytest.skip("THD doesn't support all gather context parallelism.")
+           if not load_balanced and cp_strategy == CPStrategy.RING:
+               pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
+
+       assert not use_scan_ring or cp_strategy == CPStrategy.RING
+       if use_scan_ring:
+           os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
+       else:
+           os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
+
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        attn_bias_type = AttnBiasType.NO_BIAS
        bias_shape = None
        dropout_prob = 0.0
...
@@ -324,7 +378,58 @@ class TestDistributedContextParallelSelfAttn:
            pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}")

        runner.test_backward()
-       del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]

+   @pytest.mark.parametrize(
+       "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+   )
+   @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
+   @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
+   @pytest.mark.parametrize(
+       "qkv_layout, attn_mask_type",
+       DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+   )
+   def test_context_parallel_allgather_attn_shardy(
+       self,
+       device_count,
+       mesh_shape,
+       mesh_axes,
+       mesh_resource,
+       data_shape,
+       attn_mask_type,
+       dtype,
+       qkv_layout,
+   ):
+       kv_groups = 8
+       self.impl_test_context_parallel_attn(
+           device_count,
+           mesh_shape,
+           mesh_axes,
+           mesh_resource,
+           data_shape,
+           kv_groups,
+           attn_mask_type,
+           dtype,
+           qkv_layout,
+           load_balanced=True,
+           cp_strategy=CPStrategy.ALL_GATHER,
+           use_shardy=True,
+       )

    @pytest.mark.parametrize(
        "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
    )
    @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
    @pytest.mark.parametrize("kv_groups", [1, 8])
    @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
    @pytest.mark.parametrize(
        "qkv_layout, attn_mask_type",
        DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
    )
    @pytest.mark.parametrize(
        "load_balanced",
        [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
    )
    def test_context_parallel_allgather_attn(
        self,
        device_count,
...
@@ -338,9 +443,7 @@ class TestDistributedContextParallelSelfAttn:
        qkv_layout,
        load_balanced,
    ):
-       if qkv_layout.is_thd():
-           pytest.skip("THD doesn't support all gather context parallelism.")
-
-       return self.impl_test_context_parallel_attn(
+       self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
            mesh_axes,
...
@@ -352,8 +455,23 @@ class TestDistributedContextParallelSelfAttn:
            qkv_layout,
            load_balanced,
            CPStrategy.ALL_GATHER,
+           use_shardy=False,
        )

+   @pytest.mark.parametrize(
+       "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+   )
+   @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES)
+   @pytest.mark.parametrize("kv_groups", [1, 8])
+   @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
+   @pytest.mark.parametrize(
+       "qkv_layout, attn_mask_type",
+       DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+   )
+   @pytest.mark.parametrize(
+       "load_balanced",
+       [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")],
+   )
    @pytest.mark.parametrize(
        "use_scan",
        [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")],
...
@@ -372,14 +490,6 @@ class TestDistributedContextParallelSelfAttn:
        load_balanced,
        use_scan,
    ):
-       if use_scan:
-           os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1"
-       else:
-           os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0"
-
-       if qkv_layout.is_thd() and not load_balanced:
-           pytest.skip("THD + ring doesn't support unbalanced context parallelism.")
-
        self.impl_test_context_parallel_attn(
            device_count,
            mesh_shape,
...
@@ -392,9 +502,46 @@ class TestDistributedContextParallelSelfAttn:
            qkv_layout,
            load_balanced,
            CPStrategy.RING,
+           use_shardy=False,
+           use_scan_ring=use_scan,
        )

+   @pytest.mark.parametrize(
+       "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs()
+   )
+   @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1])
+   @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")])
+   @pytest.mark.parametrize(
+       "qkv_layout, attn_mask_type",
+       DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS,
+   )
+   def test_context_parallel_ring_attn_shardy(
+       self,
+       device_count,
+       mesh_shape,
+       mesh_axes,
+       mesh_resource,
+       data_shape,
+       attn_mask_type,
+       dtype,
+       qkv_layout,
+   ):
+       kv_groups = 8
+       self.impl_test_context_parallel_attn(
+           device_count,
+           mesh_shape,
+           mesh_axes,
+           mesh_resource,
+           data_shape,
+           kv_groups,
+           attn_mask_type,
+           dtype,
+           qkv_layout,
+           load_balanced=True,
+           cp_strategy=CPStrategy.RING,
+           use_shardy=False,
+           use_scan_ring=True,
+       )
-       del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"]
-       return


class TestReorderCausalLoadBalancing:
...
tests/jax/test_distributed_layernorm.py
...
@@ -29,7 +29,7 @@ NORM_INPUT_SHAPES = {
}

is_fp8_supported, reason = is_fp8_available()
-is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
...
@@ -86,6 +86,7 @@ class TestDistributedLayernorm:
    @pytest_parametrize_wrapper("zero_centered_gamma", [False, True])
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+   @pytest_parametrize_wrapper("use_shardy", [False, True])
    def test_layernorm(
        self,
        device_count,
...
@@ -97,7 +98,9 @@ class TestDistributedLayernorm:
        zero_centered_gamma,
        shard_weights,
        fp8_recipe,
+       use_shardy,
    ):
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        epsilon = 1e-6
        ln_type = "layernorm"
        q_dtype = jnp.float8_e4m3fn
...
@@ -168,6 +171,7 @@ class TestDistributedLayernorm:
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("shard_weights", [False, True])
    @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+   @pytest_parametrize_wrapper("use_shardy", [False, True])
    def test_rmsnorm(
        self,
        device_count,
...
@@ -178,7 +182,9 @@ class TestDistributedLayernorm:
        dtype,
        shard_weights,
        fp8_recipe,
+       use_shardy,
    ):
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        epsilon = 1e-6
        ln_type = "rmsnorm"
        q_dtype = jnp.float8_e4m3fn
...
tests/jax/test_distributed_layernorm_mlp.py
...
@@ -36,7 +36,7 @@ from transformer_engine.jax.quantize import QuantizerFactory
is_fp8_supported, reason = is_fp8_available()
-is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

SUPPORTED_RECIPES = []
if is_fp8_supported:
...
@@ -45,11 +45,17 @@ if is_mxfp8_supported:
    SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

DTYPES = [jnp.bfloat16, jnp.float16]
-INPUT_SHAPE = [[2, 64, 64]]  # [batch, seqlen, hidden_in]
+INPUT_SHAPE = [[4, 64, 128]]  # [batch, seqlen, hidden_in]

+LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
+DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
+DOT_2_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_TP_AXES)
+KERNEL_1_AXES = (W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES)
+KERNEL_2_AXES = (W_TP_AXES, W_FSDP_AXES)
+LN_SCALE_AXES = (W_NO_SHARD_AXES,)
+LN_BIAS_AXES = (W_NO_SHARD_AXES,)
+BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
+BIAS_2_AXES = (W_NO_SHARD_AXES,)

INTERMEDIATE = 64
...
@@ -60,7 +66,6 @@ def generate_fsdp_and_tp_configs():
        configs.append(
            [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
        )
    if is_devices_enough(4):
        configs.append(
            [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
...
@@ -80,13 +85,13 @@ class TestDistributedLayernormMLP:
        x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
        gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
        k1 = jax.random.normal(
-           subkeys[1], (hidden_in, len(activation_type) * INTERMEDIATE), dtype
+           subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
        ) / jnp.sqrt(hidden_in)
        k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
        if use_bias:
-           b1 = jax.random.normal(subkeys[3], (len(activation_type) * INTERMEDIATE), dtype)
+           b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
            b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
        else:
            b1 = None
...
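The kernels switch from a flattened (hidden_in, n_act * intermediate) layout to an explicit (hidden_in, n_act, intermediate) axis. A small sketch of why the 3D form is convenient for gated activations (pure JAX, illustrative shapes):

import jax
import jax.numpy as jnp

hidden_in, n_act, intermediate = 128, 2, 64
x = jnp.ones((4, 64, hidden_in), jnp.bfloat16)
k1 = jnp.ones((hidden_in, n_act, intermediate), jnp.bfloat16)

# Contract hidden_in; the result keeps a dedicated activation axis.
h = jnp.einsum("bsh,han->bsan", x, k1)
assert h.shape == (4, 64, n_act, intermediate)

# A gated pair like ("gelu", "linear") then slices that axis directly.
gated = jax.nn.gelu(h[..., 0, :]) * h[..., 1, :]
assert gated.shape == (4, 64, intermediate)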
@@ -111,10 +116,12 @@ class TestDistributedLayernormMLP:
            layernorm_input_axes = LAYERNORM_INPUT_AXES
            dot_1_input_axes = DOT_1_INPUT_AXES
            dot_2_input_axes = DOT_2_INPUT_AXES
+           kernel_1_axes = KERNEL_1_AXES
+           kernel_2_axes = KERNEL_2_AXES
        else:
            layernorm_input_axes = None
-           dot_1_input_axes = None
-           dot_2_input_axes = None
+           dot_1_input_axes = dot_2_input_axes = None
+           kernel_1_axes = kernel_2_axes = None

        quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2)
...
@@ -130,21 +137,17 @@ class TestDistributedLayernormMLP:
                norm_input_axes=layernorm_input_axes,
                dot_1_input_axes=dot_1_input_axes,
                dot_2_input_axes=dot_2_input_axes,
+               kernel_1_axes=kernel_1_axes,
+               kernel_2_axes=kernel_2_axes,
                activation_type=activation_type,
                quantizer_sets=quantizer_sets,
            )
        )

-   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
-   @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
-   @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
-   @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
-   @pytest_parametrize_wrapper("dtype", DTYPES)
-   @pytest_parametrize_wrapper("use_bias", [True, False])
-   @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
-   def test_layernorm_fp8_mlp_primitive(
-       self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+   def _test_layernorm_mlp_grad(
+       self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, use_shardy
    ):
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
        layernorm_type = "rmsnorm"
...
@@ -168,12 +171,12 @@ class TestDistributedLayernormMLP:
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource):
-           k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", "tp"))
+           k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
            k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
            k1_ = jax.device_put(k1, k1_sharding)
            k2_ = jax.device_put(k2, k2_sharding)
            if use_bias:
-               b1_sharding = NamedSharding(mesh, PartitionSpec("tp"))
+               b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
                b1_ = jax.device_put(b1, b1_sharding)
            else:
                b1_sharding = b1_ = None
...
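The extra `None` entries track the new unsharded activation axis: a PartitionSpec has one entry per array axis. A minimal sketch of how the spec lines up with the 3D kernel (assumes at least two JAX devices for the "fsdp" x "tp" mesh):

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.asarray(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, ("fsdp", "tp"))

# Shard hidden_in over "fsdp", replicate the activation axis (None),
# shard intermediate over "tp".
spec = PartitionSpec("fsdp", None, "tp")
k1 = jax.device_put(np.zeros((128, 2, 64), np.float32), NamedSharding(mesh, spec))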
@@ -248,9 +251,59 @@ class TestDistributedLayernormMLP:
                    err_msg=f"multi_grads[{i}] is not close",
                )

+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+   @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+   @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+   @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+   @pytest_parametrize_wrapper("dtype", DTYPES)
+   @pytest_parametrize_wrapper("use_bias", [True, False])
+   @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
+   def test_layernorm_mlp_grad(
+       self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
+   ):
+       self._test_layernorm_mlp_grad(
+           mesh_config,
+           activation_type,
+           use_bias,
+           input_shape,
+           dtype,
+           fp8_recipe,
+           use_shardy=False,
+       )

+   @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+   @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
+   @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
+   @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")])
+   @pytest_parametrize_wrapper("dtype", DTYPES)
+   @pytest_parametrize_wrapper("use_bias", [True, False])
+   def test_layernorm_mlp_grad_shardy(
+       self, mesh_config, activation_type, use_bias, input_shape, dtype
+   ):
+       # We don't test block scaling with Shardy because at the time of writing,
+       # it is not supported in JAX's scaled_matmul_stablehlo.
+       self._test_layernorm_mlp_grad(
+           mesh_config,
+           activation_type,
+           use_bias,
+           input_shape,
+           dtype,
+           fp8_recipe=recipe.DelayedScaling(),
+           use_shardy=True,
+       )

    def _test_layernorm_mlp(
-       self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8, fp8_recipe=None
+       self,
+       mesh_config,
+       activation_type,
+       use_bias,
+       input_shape,
+       dtype,
+       use_fp8,
+       fp8_recipe,
+       use_shardy,
    ):
+       jax.config.update("jax_use_shardy_partitioner", use_shardy)
        batch, seqlen, hidden_in = input_shape
        layernorm_type = "rmsnorm"
...
@@ -269,7 +322,7 @@ class TestDistributedLayernormMLP:
activations
=
activation_type
,
use_bias
=
use_bias
,
)
params_single
=
ln_mlp_single
.
init
(
init_rngs
,
x
)
params_single
=
ln_mlp_single
.
init
(
init_rngs
,
x
,
deterministic
=
True
)
mlp_out_single
,
ln_out_single
=
ln_mlp_single
.
apply
(
params_single
,
x
,
deterministic
=
True
)
...
...
@@ -286,19 +339,19 @@ class TestDistributedLayernormMLP:
transpose_batch_sequence
=
False
,
intermediate_dim
=
INTERMEDIATE
,
activations
=
activation_type
,
scale_axes
=
(
W_NO_SHARD
_AXES
,
),
ln_bias_axes
=
(
W_NO_SHARD
_AXES
,
),
kernel_axes_1
=
(
W_FSDP_AXES
,
W_JOINED_AXES
,
W_TP
_AXES
)
,
kernel_axes_2
=
(
W_TP_AXES
,
W_FSDP
_AXES
)
,
scale_axes
=
LN_SCALE
_AXES
,
ln_bias_axes
=
LN_BIAS
_AXES
,
kernel_axes_1
=
KERNEL_1
_AXES
,
kernel_axes_2
=
KERNEL_2
_AXES
,
use_bias
=
use_bias
,
bias_axes_1
=
(
W_JOINED_AXES
,
W_TP
_AXES
)
,
bias_axes_2
=
(
W_NO_SHARD
_AXES
,
),
bias_axes_1
=
BIAS_1
_AXES
,
bias_axes_2
=
BIAS_2
_AXES
,
layernorm_input_axes
=
LAYERNORM_INPUT_AXES
,
dot_1_input_axes
=
DOT_1_INPUT_AXES
,
dot_2_input_axes
=
DOT_2_INPUT_AXES
,
name
=
"mlp"
,
)
params_sharded
=
ln_mlp_sharded
.
init
(
init_rngs
,
x
)
params_sharded
=
ln_mlp_sharded
.
init
(
init_rngs
,
x
,
deterministic
=
True
)
mlp_out_sharded
,
ln_out_sharded
=
ln_mlp_sharded
.
apply
(
params_sharded
,
x
,
deterministic
=
True
)
...
...
@@ -313,25 +366,38 @@ class TestDistributedLayernormMLP:
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"silu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
def
test_layernorm_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
):
@
pytest_parametrize_wrapper
(
"use_shardy"
,
[
False
,
True
])
def
test_layernorm_mlp_layer
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_shardy
):
self
.
_test_layernorm_mlp
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
False
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
False
,
fp8_recipe
=
None
,
use_shardy
=
use_shardy
,
)
# TODO: debug
# @pytest.mark.skipif(not is_fp8_supported, reason=reason)
# @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs())
# @pytest_parametrize_wrapper(
# "activation_type", [("gelu",), ("gelu", "linear")]
# )
# @pytest_parametrize_wrapper("use_bias", [True, False])
# @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE)
# @pytest_parametrize_wrapper("dtype", DTYPES)
# @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES)
# def test_layernorm_fp8_mlp_layer(
# self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe
# ):
# self._test_layernorm_mlp(
# mesh_config, activation_type, use_bias, input_shape, dtype,
# use_fp8=True, fp8_recipe=fp8_recipe
# )
@
pytest
.
mark
.
skipif
(
not
is_fp8_supported
,
reason
=
reason
)
@
pytest_parametrize_wrapper
(
"mesh_config"
,
generate_fsdp_and_tp_configs
())
@
pytest_parametrize_wrapper
(
"activation_type"
,
[(
"gelu"
,),
(
"gelu"
,
"linear"
)])
@
pytest_parametrize_wrapper
(
"use_bias"
,
[
True
,
False
])
@
pytest_parametrize_wrapper
(
"input_shape"
,
INPUT_SHAPE
)
@
pytest_parametrize_wrapper
(
"dtype"
,
DTYPES
)
@
pytest_parametrize_wrapper
(
"fp8_recipe"
,
SUPPORTED_RECIPES
)
def
test_layernorm_mlp_layer_fp8
(
self
,
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
fp8_recipe
):
self
.
_test_layernorm_mlp
(
mesh_config
,
activation_type
,
use_bias
,
input_shape
,
dtype
,
use_fp8
=
True
,
fp8_recipe
=
fp8_recipe
,
use_shardy
=
False
,
)
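
The migration above follows one pattern throughout this commit: each parametrized test body moves into a private `_test_*` helper, and thin `test_*` / `test_*_shardy` wrappers flip JAX's partitioner before delegating. A minimal sketch of that pattern (class and method names here are illustrative, not part of the diff):

import jax

class TestExample:
    def _impl(self, use_shardy):
        # Select GSPMD (False) or the Shardy partitioner (True) before any
        # mesh construction or jit compilation happens in the test body.
        jax.config.update("jax_use_shardy_partitioner", use_shardy)
        # ... build the mesh, shard inputs, run fwd/bwd, compare results ...

    def test_gspmd(self):
        self._impl(use_shardy=False)

    def test_shardy(self):
        self._impl(use_shardy=True)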
tests/jax/test_distributed_softmax.py  View file @ ab3e5a92
...
@@ -28,14 +28,16 @@ class TestDistributedSoftmax:
         all_reduce_loss_bytes = 4  # 1 * FP32
         return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

-    def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
+    def generate_inputs(
+        self, shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
+    ):
         batch, _, sqelen, _ = shape
         x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
         if softmax_type == SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
             mask = make_causal_mask(batch, sqelen)
         else:
-            mask = make_self_mask(batch, sqelen)
+            mask = make_self_mask(1 if broadcast_batch_mask else batch, sqelen)

         if not bad_sharding:
             x_pspec = PartitionSpec(
...
@@ -45,7 +47,11 @@
             x_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource)
-        mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
+        if broadcast_batch_mask:
+            mask_pspec = PartitionSpec(None, None, None, None)
+        else:
+            mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)

         return (x, mask), (x_pspec, mask_pspec)
...
@@ -67,16 +73,7 @@
         output = jax.nn.softmax(x * scale_factor)
         return jnp.mean(output)

-    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
-    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
-    @pytest.mark.parametrize(
-        "softmax_type",
-        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
-    )
-    @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
-    @pytest.mark.parametrize("dtype", DTYPES)
-    @pytest.mark.parametrize("bad_sharding", [False, True])
-    def test_softmax(
+    def impl_test_softmax(
         self,
         device_count,
         mesh_shape,
...
@@ -87,15 +84,20 @@ class TestDistributedSoftmax:
         scale_factor,
         dtype,
         bad_sharding,
+        broadcast_batch_mask,
+        use_shardy,
     ):
+        if broadcast_batch_mask and softmax_type != SoftmaxType.SCALED_MASKED:
+            pytest.skip("Softmax type has no mask.")
+        jax.config.update("jax_use_shardy_partitioner", use_shardy)
         target_func = partial(self.target_func, scale_factor=scale_factor, softmax_type=softmax_type)
         ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)

         (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
-            data_shape, mesh_resource, softmax_type, dtype, bad_sharding
+            data_shape, mesh_resource, softmax_type, dtype, bad_sharding, broadcast_batch_mask
         )
         collective_count_ref = self.generate_collectives_count_ref()
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
...
@@ -129,4 +131,70 @@ class TestDistributedSoftmax:
         assert "Sharding the hidden dimension is not supported" in str(w), (
             "Softmax primitive did not raise the correct warning for "
             "unsupported sharding in the hidden dimension."
             f" {str(w)}"
         )

+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
+    @pytest.mark.parametrize(
+        "softmax_type",
+        [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
+    )
+    @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
+    @pytest.mark.parametrize("dtype", DTYPES)
+    @pytest.mark.parametrize("bad_sharding", [False, True])
+    @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
+    def test_softmax(
+        self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
+        softmax_type, scale_factor, dtype, bad_sharding, broadcast_batch_mask,
+    ):
+        self.impl_test_softmax(
+            device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
+            softmax_type, scale_factor, dtype, bad_sharding, broadcast_batch_mask,
+            use_shardy=False,
+        )

+    @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+    @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED])
+    @pytest.mark.parametrize("bad_sharding", [False, True])
+    @pytest.mark.parametrize("broadcast_batch_mask", [False, True])
+    def test_softmax_shardy(
+        self, device_count, mesh_shape, mesh_axes, mesh_resource,
+        softmax_type, bad_sharding, broadcast_batch_mask,
+    ):
+        self.impl_test_softmax(
+            device_count, mesh_shape, mesh_axes, mesh_resource,
+            data_shape=[32, 12, 128, 128],
+            softmax_type=softmax_type,
+            scale_factor=1.0,
+            dtype=DTYPES[0],
+            bad_sharding=bad_sharding,
+            broadcast_batch_mask=broadcast_batch_mask,
+            use_shardy=True,
+        )
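
The new `broadcast_batch_mask` cases rely on a mask whose batch dimension is 1 being broadcast against every sample, so its PartitionSpec must leave the batch axis unsharded, while a per-sample mask can shard it over the data-parallel axis. A short illustration of the two choices (the axis name "dp" stands in for the test's `mesh_resource.dp_resource`):

from jax.sharding import PartitionSpec

# Per-sample mask of shape (batch, 1, s, s): shard the batch dim over "dp".
mask_pspec = PartitionSpec("dp", None, None, None)
# Broadcast mask of shape (1, 1, s, s): nothing to shard, replicate it.
mask_pspec_broadcast = PartitionSpec(None, None, None, None)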
tests/jax/test_layer.py  View file @ ab3e5a92
...
@@ -39,7 +39,7 @@ def enable_fused_attn():
is_fp8_supported, reason = is_fp8_available()
-is_mxfp8_supported, reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
+is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

QUANTIZE_RECIPES = []
""" Find supported scaling modes"""
...
@@ -215,12 +215,53 @@ ATTRS = [
        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
    },
+    # attrs22
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
+        _KEY_OF_WINDOW_SIZE: None,
+        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+    },
+    # attrs23
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "causal",
+        _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+    },
+    # attrs24
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
+    },
+    # attrs25
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "no_mask",
+        _KEY_OF_WINDOW_SIZE: (2, 2),
+    },
+    # attrs26
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
+        _KEY_OF_WINDOW_SIZE: (2, 2),
+    },
+    # attrs27
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
+        _KEY_OF_WINDOW_SIZE: None,
+    },
+    # attrs28
+    {
+        _KEY_OF_TRANSPOSE_BS: False,
+        _KEY_OF_RELATIVE_EMBEDDING: False,
+        _KEY_OF_WINDOW_SIZE: (2, 2),
+    },
]

ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
...
@@ -313,7 +354,7 @@ class BaseRunner:
            test_others,
            test_layer,
        )
-        if QuantizeConfig.SCALING_MODE == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
            _, updated_quantize_meta = flax.core.pop(updated_state[0], QuantizeConfig.COLLECTION_NAME)
...
@@ -370,13 +411,13 @@ class EncoderRunner(BaseRunner):
        data_rng = jax.random.PRNGKey(2024)
        inputs = (jax.random.normal(data_rng, data_shape, dtype),)
-        padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
-        causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
+        mask_shape = (batch, 1, seqlen, seqlen)
+        padded_mask = jnp.zeros(mask_shape, dtype=jnp.uint8)
+        causal_mask = jnp.triu(jnp.ones(mask_shape, dtype=jnp.uint8), k=1)
        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
            mask = causal_mask
        else:
            mask = padded_mask
        ref_masks = (1 - mask,)
        test_masks = (None, mask)
        # The second arg of Transformer is encoded tokens.
...
tests/jax/test_softmax.py  View file @ ab3e5a92
...
@@ -18,6 +18,7 @@ from utils import assert_allclose
from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
from transformer_engine.jax.softmax import SoftmaxType, softmax
+from transformer_engine.jax.flax.module import Softmax


def catch_unsupported(method):
...
@@ -94,7 +95,6 @@ class SoftmaxRunner:
            case _:
                raise ValueError(f"Unknown {self.softmax_type=}")

-    @catch_unsupported
    def test_forward(self):
        """
        Test transformer_engine.jax.softmax.softmax fwd rule
...
@@ -104,7 +104,6 @@ class SoftmaxRunner:
        reference_out = __class__.reference_softmax(self.logits, self.mask, self.scale_factor)
        assert_allclose(primitive_out, reference_out, dtype=self.dtype)

-    @catch_unsupported
    def test_backward(self):
        """
        Test transformer_engine.jax.softmax.softmax bwd rule
...
@@ -141,6 +140,50 @@ class SoftmaxRunner:
        assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)


+class SoftmaxPrimitivesRunner(SoftmaxRunner):
+    """
+    Jax Softmax Primitives runner
+    """
+
+    @catch_unsupported
+    def test_forward(self):
+        return super().test_forward()
+
+    @catch_unsupported
+    def test_backward(self):
+        return super().test_backward()
+
+
+class SoftmaxModuleRunner:
+    """
+    Jax Softmax Module runner
+    """
+
+    module_runner: SoftmaxRunner
+    bias: None
+
+    def __init__(self, module_runner, bias):
+        self.module_runner = module_runner
+        self.bias = bias
+
+    def test_forward(self):
+        """
+        Test transformer_engine.jax.flax.module.Softmax fwd rule
+        """
+        runner = self.module_runner
+        runner._setup_inputs()
+        rng = jax.random.PRNGKey(0)
+        softmax_module = Softmax(
+            scale_factor=runner.scale_factor,
+            softmax_type=runner.softmax_type,
+        )
+        softmax_vars = softmax_module.init(rng, runner.logits, runner.mask)
+        module_out = softmax_module.apply(softmax_vars, runner.logits, runner.mask)
+        reference_out = runner.reference_softmax(runner.logits, runner.mask, runner.scale_factor)
+        assert_allclose(module_out, reference_out, dtype=runner.dtype)


+# Run softmax primitives test
@pytest.mark.parametrize(
    "b, s_q, s_kv, h",
    [
...
@@ -165,7 +208,7 @@
        pytest.param(jnp.float16, id="FP16"),
    ],
)
-class TestSoftmax:
+class TestSoftmaxPrimitives:
    """
    Test transformer_engine.jax.softmax.softmax
    """
...
@@ -175,7 +218,7 @@ class TestSoftmax:
        """
        Test forward with parameterized configs
        """
-        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_forward()

    @staticmethod
...
@@ -183,5 +226,48 @@
        """
        Test forward with parameterized configs
        """
-        runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        runner = SoftmaxPrimitivesRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
        runner.test_backward()


+# Run Softmax module test
+@pytest.mark.parametrize(
+    "b, s_q, s_kv, h",
+    [
+        pytest.param(8, 16, 16, 16, id="8-16-16-16"),
+        pytest.param(8, 512, 512, 16, id="8-512-512-16"),
+        pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
+        # triggers backup framework implementation due to (s_q % 4) != 0
+        pytest.param(8, 511, 512, 16, id="8-511-512-16"),
+    ],
+)
+@pytest.mark.parametrize("scale_factor", [0.125])
+@pytest.mark.parametrize(
+    "softmax_type",
+    [
+        pytest.param(SoftmaxType.SCALED, id="SCALED"),
+        pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
+        pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        pytest.param(jnp.bfloat16, id="BF16"),
+        pytest.param(jnp.float16, id="FP16"),
+    ],
+)
+class TestSoftmaxModule:
+    """
+    Test transformer_engine.jax.flax.module.Softmax
+    """
+
+    @staticmethod
+    def test_forward(b, s_q, s_kv, h, scale_factor, softmax_type, dtype):
+        """
+        Test forward with parameterized configs
+        """
+        module_runner = SoftmaxRunner(b, s_q, s_kv, h, scale_factor, softmax_type, dtype)
+        bias = None
+        runner = SoftmaxModuleRunner(module_runner, bias)
+        runner.test_forward()
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py  View file @ ab3e5a92
...
@@ -21,7 +21,11 @@ from transformer_engine.common.recipe import (
)
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8CurrentScalingQuantizer,
+)
from transformer_engine.pytorch.tensor.utils import replace_raw_data


def _get_raw_data(quantized_tensor):
...
@@ -228,6 +232,273 @@ class MiniOptimizer:
            weight.data.copy_(master_weight)


class MiniFSDP:
    def __init__(self, weights, lr, dp_group):
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # Flatten the weights and pad to align with world size
-        raw_data_list = [
-            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
-            for w in weights
-        ]
+        if isinstance(weights[0], QuantizedTensor):
+            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
+        else:
+            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into shards
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
        shard_size = self.flatten_weight.size(0) // world_size

        # Map original tensors to flattened indices
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build shard index mappings
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            adjusted_end = min(shard_end, original_length)
            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))
            if isinstance(weights[idx], QuantizedTensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model weights and high-precision master weights
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)
                if isinstance(weight, QuantizedTensor):
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            setattr(
                weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
            )

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)
        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)
        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            flatten_tensor = torch.cat(
                [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
            )
        return flatten_tensor, original_length
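    # Worked example (hypothetical sizes): with world_size = 4 and two tensors of
    # lengths 5 and 6, torch.cat gives original_length = 11 and
    # padding_needed = (4 - 11 % 4) % 4 = 1, so the padded flat buffer holds 12
    # elements and torch.chunk(..., 4) later yields four equal shards of 3.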

    def zero_grad(self):
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
            1. Gradient reduce-scatter: Synchronize gradients across all processes.
            2. Master weight update: Update high-precision master weights using local gradients.
            3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
            4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
        dist.reduce_scatter_tensor(self.local_main_grad_shard, main_grad_buffer, group=self.dp_group)

        # Step 2: Update the master weights
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue
            # Extract the local gradient shard for this weight
            grad = self.local_main_grad_shard[shard_start:shard_end]
            # Update the master weight using gradient descent
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision
        if isinstance(self.weights[0], QuantizedTensor):
            local_weights = []
            for local_weight in self.local_weights:
                if local_weight is None:
                    local_weights.append(None)
                    continue
                local_weights.append(local_weight)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue
                # Copy updated master weights to local weights
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes
        dist.all_gather_into_tensor(self.flatten_weight, self.local_weight_shard, group=self.dp_group)
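
Taken together, one MiniFSDP iteration mirrors an FSDP/ZeRO-style update: reduce-scatter gradients, update the local master shard, cast back to model precision, all-gather the flat weights. A minimal driver sketch (learning rate and loss function are illustrative):

opt = MiniFSDP(list(model.parameters()), lr=10.0, dp_group=dp_group)
opt.zero_grad()
loss = loss_fn(model(x), target)
loss.backward()  # gradients are expected in each weight's .main_grad buffer
opt.step()       # reduce-scatter, master update, precision cast, all-gather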

def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
    rank = dist.get_rank(dp_group)
    world_size = dist.get_world_size(dp_group)

    # Configuration constants
    NUM_STEPS = 100
    SEED = 12345
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)

    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
    mock_group = mock_groups[rank]

    linear_kwargs = {
        "params_dtype": torch.bfloat16,
        "bias": False,
        "fuse_wgrad_accumulation": False,
    }

    # Create model with FP8 weights
    with te.fp8.fp8_model_init(
        enabled=quantization is not None,
        recipe=quantization_recipe(quantization),
        preserve_high_precision_init_val=True,
    ):
        model_fp8 = nn.Sequential(
            te.Linear(128, 256, **linear_kwargs),
            te.Linear(256, 256 * 3, **linear_kwargs),
            te.Linear(256 * 3, 128, **linear_kwargs),
        )

    # Create model with BF16 weights
    model = nn.Sequential(
        te.Linear(128, 256, **linear_kwargs),
        te.Linear(256, 256 * 3, **linear_kwargs),
        te.Linear(256 * 3, 128, **linear_kwargs),
    )

    # Make sure the BF16 model and FP8 model have the same initial weights
    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
        high_precision_init_val = w_fp8.get_high_precision_init_val()
        w.data.copy_(high_precision_init_val)

    optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)

    for _ in range(100):
        optimizer_fp8.zero_grad()
        optimizer.zero_grad()

        inputs = [
            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
        ]
        # Choose based on rank to make sure the inputs of different ranks are different.
        x = inputs[rank]

        with te.fp8.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y_fp8 = model_fp8(x)
        with te.fp8_autocast(
            enabled=quantization is not None,
            fp8_recipe=quantization_recipe(quantization),
            fp8_group=mock_group,
        ):
            y = model(x)

        targets = [torch.randn_like(y) for _ in range(world_size)]
        # Choose based on rank to make sure the targets of different ranks are different.
        target = targets[rank]

        loss_fp8 = nn.MSELoss()(y_fp8, target)
        loss = nn.MSELoss()(y, target)
        loss_fp8.backward()
        loss.backward()

        optimizer_fp8.step()
        optimizer.step()

        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)

    print(
        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
        f" {quantization} quantization"
    )


def _test_zero_1(dp_group):
    """Make sure the implementation of zero-1 optimizer is correct"""
    rank = dist.get_rank(dp_group)
...
@@ -389,6 +660,7 @@ def main(argv=None, namespace=None):
    dp_group = dist.new_group(backend="nccl")
    _test_zero_1(dp_group)
    _test_cast_master_weights_to_fp8(args.quantization, dp_group)
+    _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
    dist.destroy_process_group()
    return 0
...
tests/pytorch/distributed/run_numerics.py  View file @ ab3e5a92
...
@@ -19,6 +19,7 @@ from transformer_engine.common.recipe import (
    MXFP8BlockScaling,
    DelayedScaling,
    Float8CurrentScaling,
+    Float8BlockScaling,
    Format,
    Recipe,
)
...
@@ -50,6 +51,8 @@ def quantization_recipe() -> Recipe:
        return MXFP8BlockScaling()
    if QUANTIZATION == "fp8_cs":
        return Float8CurrentScaling()
+    if QUANTIZATION == "fp8_block_scaling":
+        return Float8BlockScaling()
    return te.fp8.get_default_fp8_recipe()
...
@@ -86,7 +89,7 @@ def main(argv=None, namespace=None):
    # Quantization scheme
    QUANTIZATION = args.quantization
-    if QUANTIZATION in ("fp8", "mxfp8"):
+    if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
        global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
        SEQ_LEN = 32
        BATCH_SIZE = 32
...
@@ -298,6 +301,11 @@ def _loss_backward(output_single_node, output_distributed):
    LOSS_FN(output_distributed, target).backward()


+def _loss_backward_dw(model_single_node, model_distributed):
+    model_single_node.backward_dw()
+    model_distributed.backward_dw()
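
`backward_dw` finalizes the weight-gradient computation that `delay_wgrad_compute=True` defers past the data-gradient backward pass, so the tests invoke it as an extra step only when that option is active, following the pattern below (taken from the hunks that follow):

_loss_backward(output_single_node, output_distributed)        # dgrad (and loss) now
if "delay_wgrad_compute" in kwargs:
    _loss_backward_dw(model_single_node, model_distributed)   # wgrad afterwards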

def _alloc_main_grad(model_single_node, model_distributed):
    for model in [model_single_node, model_distributed]:
        for param in model.parameters():
...
@@ -471,6 +479,10 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    # Compute delayed weight gradient
+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
...
@@ -492,6 +504,7 @@ def test_linear():
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"params_dtype": torch.float16},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column", "row"]:
...
@@ -643,6 +656,10 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    # Compute delayed weight gradient
+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
...
@@ -665,6 +682,7 @@ def test_layernorm_linear():
        {"params_dtype": torch.float16},
        {"zero_centered_gamma": False},
        {"return_layernorm_output": True},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
...
@@ -744,6 +762,9 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg
    # Compute loss and backpropagate
    _loss_backward(output_single_node, output_distributed)

+    if "delay_wgrad_compute" in kwargs:
+        _loss_backward_dw(model_single_node, model_distributed)

    # Validate outputs and gradients
    _check_outputs(output_single_node, output_distributed)
...
@@ -769,6 +790,7 @@ def test_layernorm_mlp():
        {"fuse_wgrad_accumulation": True},
        {"return_bias": True},
        {"return_layernorm_output": True},
+        {"delay_wgrad_compute": True},
    ]
    for kwargs in kwargs_list:
...
tests/pytorch/distributed/test_numerics.py  View file @ ab3e5a92
...
@@ -28,6 +28,9 @@ if torch.cuda.device_count() < 2:
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
+    FP8GlobalStateManager.is_fp8_block_scaling_available()
+)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
...
@@ -48,7 +51,7 @@ def _run_test(quantization):
all_boolean = [True, False]


-@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
+@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
...
@@ -56,4 +59,6 @@ def test_distributed(quantization):
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
+    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
+        pytest.skip(reason_for_no_fp8_block_scaling)
    _run_test(quantization)
tests/pytorch/references/blockwise_fp8_gemm_reference.py  0 → 100644  View file @ ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def fused_fma_kernel(y_ptr, x_ptr, s_ptr, M, N, y_str0, y_str1, BLOCK: tl.constexpr = 128):
    pid = tl.program_id(0)
    idx = pid * BLOCK + tl.arange(0, BLOCK)
    mask = idx < M * N
    row = idx // N
    col = idx % N
    y_offset = row * y_str0 + col * y_str1
    x_offset = row * N + col
    s_offset = row * N + col
    y = tl.load(y_ptr + y_offset, mask=mask)
    x = tl.load(x_ptr + x_offset, mask=mask)
    s = tl.load(s_ptr + s_offset, mask=mask)
    tl.store(y_ptr + y_offset, tl.fma(x, s, y), mask=mask)


def fused_fma(y, x, s, BLOCK=128):
    """
    Fused multiply-add operation (y = y + x * s).

    PyTorch does not provide a direct FMA equivalent (torch.addcmul is not bitwise
    equivalent to this operation). This function also supports cases where 'y' is
    non-contiguous in memory.
    """
    assert (
        y.shape == x.shape == s.shape and y.dim() == 2
    ), "All tensors must be 2D with the same shape"
    assert x.is_contiguous() and s.is_contiguous(), "x and s must be contiguous"
    M, N = y.shape
    grid = ((M * N + BLOCK - 1) // BLOCK,)
    fused_fma_kernel[grid](y, x, s, M, N, *y.stride(), BLOCK)
    return y
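
A quick way to sanity-check `fused_fma` against the plain two-instruction expression (an illustrative snippet, not part of the file; the comparison uses tolerances because the kernel's single-rounding FMA can differ from multiply-then-add in the last bit):

y = torch.randn(128, 256, device="cuda", dtype=torch.float32)
x = torch.randn(128, 256, device="cuda", dtype=torch.float32)
s = torch.randn(128, 256, device="cuda", dtype=torch.float32)
ref = y + x * s                    # two roundings
out = fused_fma(y.clone(), x, s)   # one rounding, updates the first argument in place
torch.testing.assert_close(out, ref, rtol=1e-6, atol=1e-6)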

class CuBLASRefBlockwiseGemm:
    """
    A cuBLAS compatible reference implementation of subchannel GEMM.
    """

    def qgemm(
        self,
        qx: torch.Tensor,
        qw: torch.Tensor,
        out_dtype: torch.dtype,
        demunged_sx: torch.Tensor,
        demunged_sw: torch.Tensor,
        quant_tile_shape_x: Tuple[int, int],
        quant_tile_shape_w: Tuple[int, int],
        bias: torch.Tensor | None = None,
        out: torch.Tensor | None = None,
        accumulate: bool = False,
        use_split_accumulator: bool = False,
    ) -> torch.Tensor:
        # demunge scale shapes for cuBLAS
        is_a_1d_scaled = quant_tile_shape_x[0] == 1
        is_b_1d_scaled = quant_tile_shape_w[0] == 1
        M, K = qx.shape
        N, K = qw.shape
        # mm_tile_shape = (tile_m, tile_n, tile_k)
        mm_tile_shape = (
            quant_tile_shape_x[0],
            quant_tile_shape_w[0],
            quant_tile_shape_w[1],
        )
        if bias is not None and bias.numel():
            # To match cuBLAS more closely when bias is applied,
            # the reference accumulates into float32, and the cast to
            # bfloat16 is deferred until after the GEMM.
            out_dtype_for_ref = torch.float32
        else:
            out_dtype_for_ref = out_dtype
        y = self.qgemm_blockwise_2d(
            qx,
            qw,
            out_dtype_for_ref,
            demunged_sx,
            demunged_sw,
            mm_tile_shape,
            use_split_accumulator,
            is_a_1d_scaled,
            is_b_1d_scaled,
        )
        if bias is not None and bias.numel():
            y += bias
            y = y.to(dtype=out_dtype)
        # cuBLAS accumulation first converts to the output dtype, then accumulates.
        if accumulate:
            assert out is not None
            y = y + out
        else:
            assert out is None, "Output tensor should be None when accumulate is False."
        return y

    @classmethod
    def qgemm_blockwise_2d(
        cls,
        qx: torch.Tensor,
        qw: torch.Tensor,
        out_dtype: torch.dtype,
        sx: torch.Tensor,
        sw: torch.Tensor,
        mm_tile_shape: Tuple[int, int, int],
        use_split_accumulator: bool,
        is_a_1d_scaled: bool,
        is_b_1d_scaled: bool,
    ) -> torch.Tensor:
        """
        Differences between the cuBLAS and CUTLASS GEMM implementations:
        - cuBLAS accumulation equation: a different equation is used for each scaling mode.
        - For the accumulation of C in the epilogue, cuBLAS first converts C to the
          output dtype, then accumulates.
        """
        M, K = qx.shape
        N, K_w = qw.shape
        assert K == K_w, "K dimension mismatch between qx and qw"
        tile_len = 128
        # Calculate grid sizes without padding
        grid_m = (M + tile_len - 1) // tile_len
        grid_n = (N + tile_len - 1) // tile_len
        grid_k = (K + tile_len - 1) // tile_len
        block_m, block_n, block_k = mm_tile_shape
        scale_m_per_tile = tile_len // block_m
        scale_n_per_tile = tile_len // block_n
        assert block_k == tile_len, "block_k must be equal to tile_len"
        # Notes on making the reference implementation numerically equivalent to the
        # blockwise FP8 GEMM:
        # 1) When using split_accumulate in FP8 GEMM, every 4 QMMA partial accumulation
        #    results are accumulated into float32 registers.
        # 2) Partial accumulation results are accumulated using FMA (Fused Multiply-Add)
        #    instructions to apply scaling factors, as in: y += partial_y * scale
        y = torch.zeros(M, N, dtype=torch.float32, device=qx.device)
        # Validate shapes of sx and sw
        scale_m_per_tensor = (M + block_m - 1) // block_m
        scale_n_per_tensor = (N + block_n - 1) // block_n
        assert sx.shape == (
            scale_m_per_tensor,
            grid_k,
        ), f"sx shape mismatch: expected ({scale_m_per_tensor}, {grid_k}), got {sx.shape}"
        assert sw.shape == (
            scale_n_per_tensor,
            grid_k,
        ), f"sw shape mismatch: expected ({scale_n_per_tensor}, {grid_k}), got {sw.shape}"
        for i in range(grid_m):
            m_start = i * tile_len
            m_end = min(m_start + tile_len, M)
            m_size = m_end - m_start
            for j in range(grid_n):
                n_start = j * tile_len
                n_end = min(n_start + tile_len, N)
                n_size = n_end - n_start
                y_block = y[m_start:m_end, n_start:n_end]
                for k in range(grid_k):
                    k_start = k * tile_len
                    k_end = min(k_start + tile_len, K)
                    k_size = k_end - k_start
                    qx_block = (
                        qx[m_start:m_end, k_start:k_end].clone().contiguous()
                    )  # Shape: [m_size, k_size]
                    qw_block = (
                        qw[n_start:n_end, k_start:k_end].clone().contiguous()
                    )  # Shape: [n_size, k_size]
                    # Extract scaling factors for the current blocks
                    sx_block = sx[i * scale_m_per_tile : (i + 1) * scale_m_per_tile, k].unsqueeze(-1)
                    sw_block = sw[j * scale_n_per_tile : (j + 1) * scale_n_per_tile, k].unsqueeze(0)
                    # Perform qgemm with scaling factors fused in the GEMM.
                    # Accumulation should be in float32 format, which aligns with
                    # split_accumulate in FP8 GEMM.
                    one = torch.tensor(1.0, dtype=torch.float32, device=qx.device)
                    y_partial = torch._scaled_mm(
                        qx_block,
                        qw_block.t(),
                        scale_a=one,
                        scale_b=one,
                        out_dtype=torch.float32,
                        use_fast_accum=not use_split_accumulator,
                    )
                    # Accumulate the partial result
                    if is_a_1d_scaled and is_b_1d_scaled:
                        # 1Dx1D
                        # cuBLAS accumulation equation: y += (y * scale_a) * scale_b
                        y_partial = y_partial * sx_block
                        # Fuse multiplication and addition to align with split_accumulate
                        # in FP8 GEMM
                        # y_block.add_(y_partial, alpha=scale.item())
                        fused_fma(
                            y_block,
                            y_partial,
                            sw_block.expand_as(y_partial).contiguous(),
                        )
                    elif not is_a_1d_scaled and is_b_1d_scaled:
                        # 2Dx1D
                        # cuBLAS accumulation equation: y += (y * scale_b) * scale_a
                        y_partial = y_partial * sw_block
                        fused_fma(
                            y_block,
                            y_partial,
                            sx_block.expand_as(y_partial).contiguous(),
                        )
                    elif is_a_1d_scaled and not is_b_1d_scaled:
                        # 1Dx2D
                        # cuBLAS accumulation equation: y += (y * scale_a) * scale_b
                        y_partial = y_partial * sx_block
                        fused_fma(
                            y_block,
                            y_partial,
                            sw_block.expand_as(y_partial).contiguous(),
                        )
                    else:
                        scale = sx_block * sw_block
                        fused_fma(y_block, y_partial, scale.expand_as(y_partial).contiguous())
        y = y.to(out_dtype)
        return y
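
In equation form, for 128-wide tiles the loop above computes, per output tile (i, j),

    y[i, j] += sx[i, k] * sw[j, k] * (qx[i, k] @ qw[j, k]^T)    accumulated in float32 over k,

with the multiply-by-scale-and-add carried out by fused_fma to match the GPU's FMA rounding. A far simpler dense reference for the 1Dx1D case, useful as a first-order check but not bitwise-matching (an illustrative sketch that assumes K is divisible by 128 and that sx/sw are the demunged dequantization scales):

def naive_blockwise_gemm_1dx1d(qx, qw, sx, sw):
    # qx: (M, K) fp8 with sx: (M, K // 128); qw: (N, K) fp8 with sw: (N, K // 128)
    x = qx.float() * sx.repeat_interleave(128, dim=1)  # dequantize row-wise
    w = qw.float() * sw.repeat_interleave(128, dim=1)
    return x @ w.t()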
tests/pytorch/references/blockwise_quantizer_reference.py  0 → 100644  View file @ ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import dataclasses
import math

import torch
from typing import Optional, Protocol, Tuple

from references.quantize_scale_calc import scale_from_amax_tensor


@dataclasses.dataclass()
class QuantizeResult:
    data: torch.Tensor
    scale: torch.Tensor
    data_t: Optional[torch.Tensor]
    scale_t: Optional[torch.Tensor]


@dataclasses.dataclass()
class CuBLASScaleMunger:
    def munge_scale_shapes_for_backend(
        self,
        unmunged: QuantizeResult,
        tile_shape: Tuple[int, int],
    ) -> QuantizeResult:
        """
        cuBLAS GEMM requires 1x128 quantized tensors to have their scales transposed,
        so that for an (M, N) tensor the scales are (RoundUpDiv(N, 128), RoundUp(M, 4)).

        For 128x128 quantized tensors, the GEMM expects an
        (M, PadToAlign(RoundUpDiv(N, 128), 4)) format. If RoundUpDiv(N, 128) is not
        divisible by 4, a transformation is required.
        """

        def _pad_inner_to_align(s: torch.Tensor, transpose: bool) -> torch.Tensor:
            if transpose:
                s = s.transpose(-1, -2).contiguous()
            M, K = s.shape
            if K % 4 == 0:
                return s
            k_pad = 4 - (K % 4)
            return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous()

        s = _pad_inner_to_align(unmunged.scale, transpose=tile_shape[0] == 1)
        if unmunged.scale_t is None:
            s_t = None
        else:
            s_t = _pad_inner_to_align(unmunged.scale_t, transpose=tile_shape[0] == 1)
        return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t)
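    # Shape example (illustrative): a (256, 512) tensor quantized 1x128 has raw
    # scales of shape (256, 4); cuBLAS wants them transposed to (4, 256), whose
    # inner dim is already 4-aligned. With 128x128 tiles on a (256, 384) tensor
    # the scales are (2, 3), and the inner dim is padded to give (2, 4).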

    @classmethod
    def demunge_scale_shape_from_backend(
        cls,
        qtensor_shape: Tuple[int, int],
        scales: torch.Tensor,
        tile_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """
        Inverse operation of munge_scale_shapes_for_backend.
        """
        if tile_shape[0] != 1:
            # 2D block quantized tensor may need padding stripped off
            derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1])
        else:
            derived_scale_k_shape = qtensor_shape[0]
        M, K = scales.shape
        if derived_scale_k_shape != K:
            scales = scales[:, :derived_scale_k_shape].contiguous()
        if tile_shape[0] == 1:
            return scales.transpose(-1, -2).contiguous()
        else:
            return scales


@dataclasses.dataclass()
class BlockwiseQuantizerReference:
    """
    A reference QuantizeOp for subchannel/block hybrid quantization.

    Defers to reference GEMMs and quantization formatting based on the backend.
    """

    def __init__(self) -> None:
        self.scale_munger = CuBLASScaleMunger()

    @classmethod
    def _quantize_square_block_tiling(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        return_transpose: bool,
        pow_2_scales: bool,
        eps: float,
    ) -> QuantizeResult:
        M, K = x.shape
        pad_m_k = [0, 0]
        if K % tile_len != 0:
            pad_m_k[1] = tile_len - (K % tile_len)
        if M % tile_len != 0:
            pad_m_k[0] = tile_len - (M % tile_len)
        unpadded_m, unpadded_k = M, K
        if pad_m_k[0] != 0 or pad_m_k[1] != 0:
            x = torch.nn.functional.pad(
                x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0
            ).contiguous()
            M, K = x.shape
        x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len)
        amax_grid = (
            torch.abs(x_tiled.transpose(-3, -2))
            .reshape(M // tile_len, K // tile_len, tile_len**2)
            .amax(dim=-1)
        ).float()
        dtype_max = torch.finfo(quant_dtype).max
        scale, scale_inv, _ = scale_from_amax_tensor(
            x_dtype=x.dtype,
            amax=amax_grid,
            quant_dtype=quant_dtype,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
        qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
        qx = qx.to(dtype=quant_dtype)
        qx = qx.reshape(M, K)
        if unpadded_k != K or unpadded_m != M:
            qx = qx[:unpadded_m, :unpadded_k].contiguous()
        if return_transpose:
            # Valid because of square block sizes
            qx_t = qx.transpose(-1, -2).contiguous()
            scale_inv_t = scale_inv.transpose(-1, -2).contiguous()
        else:
            qx_t = None
            scale_inv_t = None
        return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t)

    @classmethod
    def _quantize_vectorwise_reference(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        pow_2_scales: bool,
        eps: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        M, K = x.shape
        dtype_max = torch.finfo(quant_dtype).max
        x_tiled = x.reshape(M, K // tile_len, tile_len)
        amax_grid = torch.abs(x_tiled).amax(dim=-1).float()
        scale, scale_inv, _ = scale_from_amax_tensor(
            x_dtype=x.dtype,
            amax=amax_grid,
            quant_dtype=quant_dtype,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        qx = x_tiled * scale.reshape(M, K // tile_len, 1)
        qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
        qx = qx.to(dtype=quant_dtype)
        qx = qx.reshape(M, K)
        return qx, scale_inv
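    # Worked example (illustrative): for one 1x128 tile with amax = 3.0 quantized
    # to float8_e4m3fn (finfo max 448), an amax-based scale is roughly
    # 448 / 3.0 ≈ 149.3; values are multiplied by it, clamped to [-448, 448],
    # cast to fp8, and scale_inv ≈ 1 / 149.3 ≈ 0.0067 is kept for dequantization.
    # (scale_from_amax_tensor may adjust this via eps or power-of-2 rounding.)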

    @classmethod
    def _quantize_vector_tiling(
        cls,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        tile_len: int,
        *,
        return_transpose: bool,
        pow_2_scales: bool,
        eps: float,
    ) -> QuantizeResult:
        M, K = x.shape
        if K % tile_len == 0:
            qref_input = x
        else:
            pad_amount = tile_len - (K % tile_len)
            pad = (0, pad_amount)
            qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        qout_padded, scale_inv = cls._quantize_vectorwise_reference(
            qref_input,
            quant_dtype,
            tile_len=tile_len,
            pow_2_scales=pow_2_scales,
            eps=eps,
        )
        if K % tile_len == 0:
            qout = qout_padded
        else:
            qout = qout_padded[:, :K].contiguous()
        if return_transpose:
            if M % tile_len == 0:
                qref_input = x.transpose(-1, -2).contiguous()
            else:
                amount_to_pad = tile_len - (M % tile_len)
                pad = (0, amount_to_pad)
                qref_input = torch.nn.functional.pad(
                    x.transpose(-1, -2), pad, mode="constant", value=0
                ).contiguous()
            qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference(
                qref_input,
                quant_dtype,
                tile_len=tile_len,
                pow_2_scales=pow_2_scales,
                eps=eps,
            )
            if M % tile_len == 0:
                qout_t = qout_t_padded
            else:
                qout_t = qout_t_padded[:, :M].contiguous()
        else:
            qout_t, scale_inv_t = None, None
        return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t)

    def ref_dequantize_rowwise(
        self,
        q: torch.Tensor,
        quant_tile_shape: Tuple[int, int],
        s: torch.Tensor,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        assert q.dim() == 2
        q_M, q_K = q.shape
        s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape)
        assert len(s.shape) == 2
        m_tiles, k_tiles = s.shape
        M, K = q.shape
        unpadded_m, unpadded_k = M, K
        if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0:
            m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0]
            k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1]
            q = torch.nn.functional.pad(
                q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
            ).contiguous()
            M, K = q.shape
        q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1])
        result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1)
        result = result.view(M, K).to(dtype)
        if M != unpadded_m or K != unpadded_k:
            result = result[:unpadded_m, :unpadded_k].contiguous()
        return result

    def quantize(
        self,
        x: torch.Tensor,
        quant_dtype: torch.dtype,
        return_transpose: bool = False,
        eps: float = 0.0,
        pow_2_scales: bool = False,
        quant_tile_shape: Tuple[int, int] = (128, 128),
    ) -> QuantizeResult:
        # sanity checks
        assert x.dim() == 2
        assert x.dtype in (
            torch.float,
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ), "Unsupported input dtype."
        assert quant_dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        ), "Unsupported quant dtype."
        assert quant_tile_shape in ((1, 128), (128, 128))
        if quant_tile_shape[0] == 1:
            # Quantize row-wise
            return self.scale_munger.munge_scale_shapes_for_backend(
                self._quantize_vector_tiling(
                    x,
                    quant_dtype,
                    tile_len=quant_tile_shape[1],
                    return_transpose=return_transpose,
                    pow_2_scales=pow_2_scales,
                    eps=eps,
                ),
                quant_tile_shape,
            )
        else:
            # Quantize block-wise
            return self.scale_munger.munge_scale_shapes_for_backend(
                self._quantize_square_block_tiling(
                    x,
                    quant_dtype,
                    tile_len=quant_tile_shape[0],
                    return_transpose=return_transpose,
                    pow_2_scales=pow_2_scales,
                    eps=eps,
                ),
                quant_tile_shape,
            )
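
A minimal usage sketch of the reference quantizer round trip (shapes, dtypes, and the loose tolerances are illustrative; FP8 quantization is lossy by design):

quantizer = BlockwiseQuantizerReference()
x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
qres = quantizer.quantize(
    x, torch.float8_e4m3fn, return_transpose=True, quant_tile_shape=(1, 128)
)
x_back = quantizer.ref_dequantize_rowwise(qres.data, (1, 128), qres.scale, torch.bfloat16)
torch.testing.assert_close(x_back, x, rtol=0.1, atol=0.25)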