OpenDAS / TransformerEngine

Commit d1bf39cf, authored Jun 18, 2025 by wenjh
Parent: 3653fbfb

    Fix lack of lds in vector_blockwise

    Signed-off-by: wenjh <wenjh@sugon.com>

Showing 3 changed files with 347 additions and 200 deletions (+347, -200).
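For readers outside the ROCm world: "LDS" is the Local Data Share, the on-chip scratchpad memory of AMD GPUs that HIP and CUDA code address as __shared__ memory. The commit title says the vector_blockwise quantize-transpose kernel was missing LDS usage. The kernel fix itself sits in the collapsed diff to quantize_transpose_vector_blockwise.cu at the bottom (a generic sketch of the technique follows that diff), while the two test diffs below revert AMD-specific comparison tolerances, which suggests the fixed kernel once again matches the reference implementation exactly.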
tests/cpp/operator/test_cast_float8blockwise.cu                               +3   -33
tests/pytorch/test_float8_blockwise_scaling_exact.py                          +4   -4
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu    +340 -163
tests/cpp/operator/test_cast_float8blockwise.cu
@@ -263,30 +263,16 @@ void compare_scaling_factors(const std::string& name, const float* test, const f
 void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test,
                                                     const float* ref, const size_t rows,
-                                                    const size_t col_blocks
-#ifdef __HIP_PLATFORM_AMD__
-                                                    , double atol = 0., double rtol = 0.
-#endif
-) {
+                                                    const size_t col_blocks) {
   const size_t test_stride = scale_align_stride(rows);
   for (int i = 0; i < rows; ++i) {
     for (int j = 0; j < col_blocks; ++j) {
       const int test_idx = i + test_stride * j;
       const int ref_idx = i + rows * j;
-#ifdef __HIP_PLATFORM_AMD__
-      double t = static_cast<double>(static_cast<float>(test[test_idx]));
-      double r = static_cast<double>(static_cast<float>(ref[ref_idx]));
-      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
-      ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
-                             << "Mismatch: " << t << " vs " << r << " at index " << test_idx
-                             << "," << ref_idx;
-#else
       ASSERT_FALSE(test[test_idx] != ref[ref_idx]) << "Error in " << name << std::endl
                                                    << "Mismatch: " << test[test_idx] << " vs "
                                                    << ref[ref_idx] << " at index " << test_idx
                                                    << "," << ref_idx;
-#endif
     }
   }
 }
@@ -425,33 +411,17 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
   float atol = 0.0;
   float rtol = 0.0;
-#ifdef __HIP_PLATFORM_AMD__
-  double atol_scale = 0.0;
-  double rtol_scale = 0.0;
-  if (itype == DType::kFloat32) {
-    atol_scale = 1e-5;
-  }
-#endif
   if (rowwise) {
     compareResults("output_c", output_c, ref_output.get(), true, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv",
                                                    output_c.rowwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv.get(), rows, blocks_x
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-                                                   );
+                                                   ref_scale_inv.get(), rows, blocks_x);
   }
   if (colwise) {
     compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol);
     compare_scaling_factors_one_dimensional_blocks("scale_inv_t",
                                                    output_c.columnwise_cpu_scale_inv_ptr<float>(),
-                                                   ref_scale_inv_t.get(), cols, blocks_x_t
-#ifdef __HIP_PLATFORM_AMD__
-                                                   , atol_scale, rtol_scale
-#endif
-                                                   );
+                                                   ref_scale_inv_t.get(), cols, blocks_x_t);
   }
 }
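One detail worth noting in compare_scaling_factors_one_dimensional_blocks above: the scales are compared column-major, with the kernel output using a padded leading dimension (test[i + test_stride * j]) while the reference is densely packed (ref[i + rows * j]). The definition of scale_align_stride is not part of this diff; the sketch below assumes it simply rounds the row count up to an alignment boundary, purely to illustrate the layout.

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the repo's scale_align_stride(); the real
// alignment constant is not visible in this diff.
constexpr std::size_t scale_align_stride_sketch(std::size_t rows) {
  constexpr std::size_t kAlign = 4;              // assumed alignment
  return (rows + kAlign - 1) / kAlign * kAlign;  // round up to a multiple
}

// Column-major scale matrix with a padded leading dimension: element (i, j)
// lives at i + stride * j. With 6 rows and kAlign = 4 the stride is 8, so
// rows 6..7 of every column are padding the comparison must skip over.
float scale_at(const std::vector<float>& scales, std::size_t i, std::size_t j,
               std::size_t rows) {
  return scales[i + scale_align_stride_sketch(rows) * j];
}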
tests/pytorch/test_float8_blockwise_scaling_exact.py
@@ -153,7 +153,7 @@ def check_quantization_block_tiling_versus_reference(
     )
     # Check
-    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0)
+    torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0)
     # Zero out values that are don't care values
     # Scale format has padding.
     scale_mask = torch.ones(
@@ -163,7 +163,7 @@ def check_quantization_block_tiling_versus_reference(
         QuantizeResult(qx, scale_mask, None, None), tile_size
     ).scale
     sx = sx * scale_mask
-    torch.testing.assert_close(sx, sx_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+    torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0)
     if return_transpose:
         assert qx_t is not None
@@ -179,8 +179,8 @@ def check_quantization_block_tiling_versus_reference(
             QuantizeResult(qx_t, scale_mask, None, None), tile_size
         ).scale
         sx_t = sx_t * scale_mask
-        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0 if quant_dtype != torch.int8 else 1.0, rtol=0.0 if x_dtype != torch.float32 else 2.5e-1)
-        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0 if x_dtype != torch.float32 else 1e-5, rtol=0.0 if x_dtype != torch.float32 else 5e-5)
+        torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0)
+        torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0)
     else:
         # should be None
         assert qx_t is None and qx_t_ref is None
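Both test files make the same reversal: the AMD-only tolerance plumbing (the __HIP_PLATFORM_AMD__-guarded atol/rtol parameters in the C++ test, the dtype-conditional tolerances in the PyTorch test) is dropped in favour of exact atol=0.0 / rtol=0.0 comparisons. That is consistent with the commit message: once the kernel stages its data through LDS properly, its quantized outputs and inverse scales apparently agree with the reference bit for bit again, so the loosened tolerances that had been introduced as a workaround are no longer needed.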
transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu

(The diff for this file is collapsed in this view: +340 additions, -163 deletions.)
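Since the actual kernel change is collapsed here, the following is a generic, heavily simplified sketch of the technique the commit title points at: staging a tile in __shared__ memory (LDS on AMD hardware) so that a blockwise FP8 quantize can emit both the row-major output and its transpose with coalesced global accesses. Every name, the 32x32 tile shape, and the 1x32 scaling blocks are assumptions for illustration; this is not TransformerEngine's implementation.

#include <cuda_fp8.h>
#include <cuda_runtime.h>

constexpr int TILE = 32;

// Sketch only: per-(1 x TILE)-block FP8 quantization plus transpose, with the
// input tile staged through __shared__ memory (LDS) so that both the load of
// `in` and the store of `out_t` are coalesced.
__global__ void quantize_transpose_sketch(const float* __restrict__ in,
                                          __nv_fp8_e4m3* __restrict__ out,
                                          __nv_fp8_e4m3* __restrict__ out_t,
                                          float* __restrict__ scale_inv,
                                          int rows, int cols) {
  __shared__ float tile[TILE][TILE + 1];  // +1 column pad: no bank conflicts
  __shared__ float amax[TILE];            // abs-max of each 1 x TILE block

  int col = blockIdx.x * TILE + threadIdx.x;
  int row = blockIdx.y * TILE + threadIdx.y;

  // Coalesced load into LDS; zero-fill out-of-range lanes.
  tile[threadIdx.y][threadIdx.x] =
      (row < rows && col < cols) ? in[row * cols + col] : 0.0f;
  __syncthreads();

  // Serial per-row reduction, kept simple for clarity (a real kernel would
  // use a warp-level tree reduction).
  if (threadIdx.x == 0) {
    float m = 0.0f;
    for (int k = 0; k < TILE; ++k) m = fmaxf(m, fabsf(tile[threadIdx.y][k]));
    amax[threadIdx.y] = m;
  }
  __syncthreads();

  // Map the block amax onto the FP8 E4M3 maximum (448) and quantize.
  float scale = amax[threadIdx.y] > 0.0f ? 448.0f / amax[threadIdx.y] : 1.0f;
  if (row < rows && col < cols) {
    out[row * cols + col] = __nv_fp8_e4m3(tile[threadIdx.y][threadIdx.x] * scale);
    if (threadIdx.x == 0) scale_inv[row * gridDim.x + blockIdx.x] = 1.0f / scale;
  }

  // Transposed store: read LDS with swapped thread indices so the write to
  // out_t (a cols x rows tensor) is coalesced; reuse the source row's scale.
  int t_col = blockIdx.y * TILE + threadIdx.x;
  int t_row = blockIdx.x * TILE + threadIdx.y;
  if (t_row < cols && t_col < rows) {
    float s = amax[threadIdx.x] > 0.0f ? 448.0f / amax[threadIdx.x] : 1.0f;
    out_t[t_row * rows + t_col] = __nv_fp8_e4m3(tile[threadIdx.x][threadIdx.y] * s);
  }
}

The essential point for this commit is the shared-memory tile: without it, the transposed store would either read global memory with a strided, uncoalesced pattern or recompute data other threads already loaded, which is presumably the "lack of lds" the fix addresses.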