Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
ace6e18e
Commit
ace6e18e
authored
Feb 04, 2026
by
lijian6
Committed by
niuhb
Feb 04, 2026
Browse files
modify quant test.
Signed-off-by:
lijian
<
lijian6@sugon.com
>
parent
2d655524
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
254 deletions
+49
-254
tests/test_low_latency_new.py
tests/test_low_latency_new.py
+38
-23
tests/test_low_latency_new_int8.py
tests/test_low_latency_new_int8.py
+0
-211
tests/utils.py
tests/utils.py
+11
-20
No files found.
tests/test_low_latency_new.py
View file @
ace6e18e
...
...
@@ -6,7 +6,7 @@ from functools import partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_
pg_back
,
per_token_cast_pc_
back
def
simulate_failure_and_skip
(
rank
:
int
,
api
:
Literal
[
"dispatch"
,
"combine"
,
"clean"
],
expected_masked_ranks
:
Set
[
int
]):
...
...
@@ -81,12 +81,11 @@ def test_main(num_tokens: int,
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
0
,
2
,
3
,
):
# 0: 不量化, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True)
dispatch_use_fp8
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
dispatch_use_fp8
else
(
False
,
):
for
quant_group_size
in
(
128
,
):
# 跳过不支持的情况
if
quant_type
==
3
and
fp8_round_scale
==
False
:
for
quant_type
in
(
0
,
1
,
2
,
3
,
):
# 0: 不量化, 1: int8, 2: FP8_E4M3, 3: FP8_UE8M0 (仅支持round_scale=True), 4: FP8_E5M2
dispatch_use_quant
=
quant_type
>
0
for
fp8_round_scale
in
(
False
,
True
)
if
quant_type
!=
3
else
(
True
,
):
for
quant_group_size
in
(
0
,
128
,)
if
quant_type
>=
2
else
(
0
,
):
if
quant_type
==
3
and
(
fp8_round_scale
==
False
or
quant_group_size
==
0
):
continue
num_times
+=
1
...
...
@@ -96,12 +95,21 @@ def test_main(num_tokens: int,
quant_type
=
quant_type
,
fp8_round_scale
=
fp8_round_scale
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
\
if
dispatch_use_fp8
else
packed_recv_x
.
clone
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_quant
else
packed_recv_x
if
not
dispatch_use_quant
:
simulated_gemm_x
=
packed_recv_x
.
clone
()
elif
quant_group_size
==
0
:
simulated_gemm_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
reshape
(
-
1
)).
view
(
packed_recv_x
[
0
].
shape
)
elif
quant_group_size
==
128
:
simulated_gemm_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
hidden
//
128
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
if
not
dispatch_use_quant
:
recv_x
=
packed_recv_x
[
i
]
elif
quant_group_size
==
0
:
recv_x
=
per_token_cast_pc_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
elif
quant_group_size
==
128
:
recv_x
=
per_token_cast_pg_back
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
...
...
@@ -119,8 +127,15 @@ def test_main(num_tokens: int,
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
dispatch_use_quant
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
assert
torch
.
equal
(
recv_x_amin
,
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
))
if
quant_group_size
!=
0
:
if
fp8_round_scale
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.007
else
:
...
...
@@ -130,7 +145,7 @@ def test_main(num_tokens: int,
if
not
fp8_round_scale
:
assert
(
recv_x_amin
==
j
-
rank_offset
).
sum
().
item
()
==
(
all_topk_idx
[
j
]
==
expert_id
).
sum
().
item
()
assert
(
recv_x
[
begin_idx
:
begin_idx
+
count
,
:
-
128
]
-
j
+
rank_offset
).
sum
().
item
()
==
0
if
dispatch_use_
fp8
:
if
dispatch_use_
quant
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
...
...
@@ -154,7 +169,7 @@ def test_main(num_tokens: int,
diff
=
calc_diff
(
current_x
*
topk_weights
.
masked_fill
(
topk_idx
==
-
1
,
0
).
sum
(
dim
=
1
).
view
(
-
1
,
1
),
combined_x
)
assert
torch
.
isnan
(
combined_x
).
sum
().
item
()
==
0
# if not fp8_round_scale:
assert
diff
<
(
9e-4
if
dispatch_use_
fp8
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
fp8
=
{
dispatch_use_
fp8
}
, zero_copy=
{
zero_copy
}
'
assert
diff
<
(
9e-4
if
dispatch_use_
quant
else
1e-5
),
f
'Error: diff=
{
diff
}
, dispatch_use_
quant
=
{
dispatch_use_
quant
}
, zero_copy=
{
zero_copy
}
'
hash_value
^=
hash_tensor
(
combined_x
)
# noinspection PyShadowingNames
...
...
@@ -168,7 +183,7 @@ def test_main(num_tokens: int,
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
2
,
quant_group_size
=
128
,
quant_type
=
2
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
...
...
tests/test_low_latency_new_int8.py
deleted
100644 → 0
View file @
2d655524
import
argparse
import
random
import
os
import
torch
import
torch.distributed
as
dist
from
functools
import
partial
from
typing
import
Literal
,
Set
import
deep_ep
from
utils
import
init_dist
,
bench
,
bench_kineto
,
calc_diff
,
hash_tensor
,
per_token_cast_back_int8
def
test_main
(
num_tokens
:
int
,
hidden
:
int
,
num_experts
:
int
,
num_topk
:
int
,
rank
:
int
,
num_ranks
:
int
,
group
:
dist
.
ProcessGroup
,
buffer
:
deep_ep
.
Buffer
,
seed
:
int
=
0
):
torch
.
manual_seed
(
seed
+
rank
)
random
.
seed
(
seed
+
rank
)
assert
num_experts
%
num_ranks
==
0
num_local_experts
=
num_experts
//
num_ranks
# NOTES: the integers greater than 256 exceed the BF16 precision limit
rank_offset
=
128
assert
num_ranks
-
rank_offset
<
257
,
'Too many ranks (exceeding test precision limit)'
x
=
torch
.
ones
((
num_tokens
,
hidden
),
dtype
=
torch
.
bfloat16
,
device
=
'cuda'
)
*
(
rank
-
rank_offset
)
x
[:,
-
128
:]
=
torch
.
arange
(
num_tokens
,
device
=
'cuda'
).
to
(
torch
.
bfloat16
).
view
(
-
1
,
1
)
x_list
=
[
x
]
# # NOTES: the last one is for performance testing
# # Most of the values in the perf case is lower than the threshold, casting most channels
# x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1)
scores
=
torch
.
randn
((
num_tokens
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
+
1
topk_idx
=
torch
.
topk
(
scores
,
num_topk
,
dim
=-
1
,
largest
=
True
,
sorted
=
True
)[
1
]
topk_weights
=
torch
.
randn
((
num_tokens
,
num_topk
),
dtype
=
torch
.
float32
,
device
=
'cuda'
).
abs
()
# Randomly mask some positions
for
_
in
range
(
10
):
topk_idx
[
random
.
randint
(
0
,
num_tokens
-
1
),
random
.
randint
(
0
,
num_topk
-
1
)]
=
-
1
all_topk_idx
=
torch
.
empty
((
num_ranks
,
num_tokens
,
num_topk
),
dtype
=
topk_idx
.
dtype
,
device
=
'cuda'
)
dist
.
all_gather_into_tensor
(
all_topk_idx
,
topk_idx
,
group
=
group
)
# For failure simulation and shrink testing
mask_status
=
torch
.
zeros
((
num_ranks
,),
dtype
=
torch
.
int
,
device
=
'cuda'
)
# Check dispatch correctness
do_check
=
True
hash_value
,
num_times
=
0
,
0
for
current_x
in
x_list
:
for
return_recv_hook
in
(
False
,
True
):
for
quant_type
in
(
1
,
):
for
fp8_round_scale
in
(
False
,
):
for
quant_group_size
in
(
0
,
):
dispatch_use_fp8
=
quant_type
>
0
num_times
+=
1
for
_
in
range
(
1
):
packed_recv_x
,
packed_recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
quant_type
,
quant_group_size
=
quant_group_size
,
async_finish
=
not
return_recv_hook
,
return_recv_hook
=
return_recv_hook
)
hook
()
if
return_recv_hook
else
event
.
current_stream_wait
()
packed_recv_x
=
(
packed_recv_x
[
0
],
packed_recv_x
[
1
].
contiguous
())
if
dispatch_use_fp8
else
packed_recv_x
simulated_gemm_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
].
view
(
-
1
,
hidden
),
packed_recv_x
[
1
].
view
(
-
1
,
1
)).
view
(
packed_recv_x
[
0
].
shape
)
for
i
in
range
(
num_local_experts
if
do_check
else
0
):
expert_id
=
rank
*
num_local_experts
+
i
recv_x
=
per_token_cast_back_int8
(
packed_recv_x
[
0
][
i
],
packed_recv_x
[
1
][
i
])
if
dispatch_use_fp8
else
packed_recv_x
[
i
]
recv_count
,
recv_src_info
,
recv_layout_range
=
packed_recv_count
[
i
],
handle
[
0
][
i
],
handle
[
1
][
i
]
# Check expert indices
int_mask
=
(
2
**
32
)
-
1
num_valid_tokens
=
recv_count
.
item
()
assert
num_valid_tokens
==
(
recv_layout_range
&
int_mask
).
sum
().
item
(),
f
'
{
num_valid_tokens
}
!=
{
recv_layout_range
&
int_mask
}
.sum().item()'
assert
num_valid_tokens
==
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
(
),
f
'
{
num_valid_tokens
}
!=
{
(
all_topk_idx
==
expert_id
).
sum
(
dim
=
[
1
,
2
])[
mask_status
==
0
].
sum
().
item
()
}
'
if
num_valid_tokens
==
0
:
continue
# Check received data
if
current_x
is
x
:
recv_x
=
recv_x
[:
num_valid_tokens
]
recv_x_amin
=
recv_x
[:,
:
-
128
].
amin
(
dim
=-
1
)
recv_x_amax
=
recv_x
[:,
:
-
128
].
amax
(
dim
=-
1
)
recv_src_info
=
recv_src_info
[:
num_valid_tokens
]
assert
torch
.
equal
(
recv_x_amin
,
recv_x_amax
)
if
quant_type
==
1
:
assert
calc_diff
(
recv_x
[:,
-
1
],
recv_src_info
.
view
(
-
1
))
<
0.01
else
:
assert
(
recv_x
[:,
-
128
:]
-
recv_src_info
.
view
(
-
1
,
1
)
%
num_tokens
).
sum
().
item
()
==
0
# for j in range(num_ranks):
# if (not round_scale):
# check_tmp1 = (recv_x_amin == j - rank_offset).sum().item()
# check_tmp2 = (all_topk_idx[j] == expert_id).sum().item()
# print(f'rank: {rank}, j: {j}, check_tmp1: {check_tmp1}, check_tmp2: {check_tmp2}, diff: {abs(check_tmp1 - check_tmp2)}')
# assert abs(check_tmp1 - check_tmp2) < 3
# assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0
if
dispatch_use_fp8
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
0
][
i
,
:
num_valid_tokens
])
hash_value
^=
hash_tensor
(
packed_recv_x
[
1
][
i
,
:
num_valid_tokens
])
else
:
hash_value
^=
hash_tensor
(
packed_recv_x
[
i
,
:
num_valid_tokens
])
print
(
"dispatch int 8 pass"
)
# noinspection PyShadowingNames
def
large_gemm_with_hook
(
hook
):
mat_0
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_1
=
torch
.
randn
((
8192
,
8192
),
dtype
=
torch
.
float
)
mat_0
@
mat_1
hook
()
# noinspection PyShadowingNames
def
test_func
(
return_recv_hook
:
bool
):
recv_x
,
recv_count
,
handle
,
event
,
hook
=
\
buffer
.
low_latency_dispatch
(
current_x
,
topk_idx
,
num_tokens
,
num_experts
,
quant_type
=
1
,
quant_group_size
=
0
,
async_finish
=
False
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
combined_x
,
event
,
hook
=
buffer
.
low_latency_combine
(
simulated_gemm_x
,
topk_idx
,
topk_weights
,
handle
,
return_recv_hook
=
return_recv_hook
)
large_gemm_with_hook
(
hook
)
if
return_recv_hook
else
None
# Calculate bandwidth
scale_size
=
1
# hidden / 128
num_fp8_bytes
,
num_bf16_bytes
=
(
hidden
+
scale_size
*
4
+
16
),
hidden
*
2
num_dispatch_comm_bytes
,
num_combine_comm_bytes
=
0
,
0
for
i
in
range
(
num_tokens
):
num_selections
=
(
topk_idx
[
i
]
!=
-
1
).
sum
().
item
()
num_dispatch_comm_bytes
+=
num_fp8_bytes
*
num_selections
num_combine_comm_bytes
+=
num_bf16_bytes
*
num_selections
# Separate profiling
for
return_recv_hook
in
(
True
,
False
):
group
.
barrier
()
dispatch_t
,
combine_t
=
bench_kineto
(
partial
(
test_func
,
return_recv_hook
=
return_recv_hook
),
kernel_names
=
(
'dispatch'
,
'combine'
),
barrier_comm_profiling
=
True
,
suppress_kineto_output
=
True
,
num_kernels_per_period
=
2
if
return_recv_hook
else
1
)
if
not
return_recv_hook
:
print
(
f
'[rank
{
rank
}
] Dispatch bandwidth:
{
num_dispatch_comm_bytes
/
1e9
/
dispatch_t
:.
2
f
}
GB/s, avg_t=
{
dispatch_t
*
1e6
:.
2
f
}
us | '
f
'Combine bandwidth:
{
num_combine_comm_bytes
/
1e9
/
combine_t
:.
2
f
}
GB/s, avg_t=
{
combine_t
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
else
:
print
(
f
'[rank
{
rank
}
] Dispatch send/recv time:
{
dispatch_t
[
0
]
*
1e6
:.
2
f
}
+
{
dispatch_t
[
1
]
*
1e6
:.
2
f
}
us | '
f
'Combine send/recv time:
{
combine_t
[
0
]
*
1e6
:.
2
f
}
+
{
combine_t
[
1
]
*
1e6
:.
2
f
}
us'
,
flush
=
True
)
return
hash_value
# noinspection PyUnboundLocalVariable,PyShadowingNames
def
test_loop
(
local_rank
:
int
,
num_local_ranks
:
int
,
args
:
argparse
.
Namespace
):
rank
,
num_ranks
,
group
=
init_dist
(
local_rank
,
num_local_ranks
)
num_tokens
,
hidden
=
args
.
num_tokens
,
args
.
hidden
num_topk
,
num_experts
=
args
.
num_topk
,
args
.
num_experts
num_rdma_bytes
=
deep_ep
.
Buffer
.
get_low_latency_rdma_size_hint
(
num_tokens
,
hidden
,
num_ranks
,
num_experts
)
if
local_rank
==
0
:
print
(
f
'Allocating buffer size:
{
num_rdma_bytes
/
1e6
}
MB ...'
,
flush
=
True
)
buffer
=
deep_ep
.
Buffer
(
group
,
num_rdma_bytes
=
num_rdma_bytes
,
low_latency_mode
=
True
,
num_qps_per_rank
=
num_experts
//
num_ranks
,
allow_nvlink_for_low_latency_mode
=
not
args
.
disable_nvlink
,
explicitly_destroy
=
True
,
allow_mnnvl
=
args
.
allow_mnnvl
)
test_main
(
num_tokens
,
hidden
,
num_experts
,
num_topk
,
rank
,
num_ranks
,
group
,
buffer
,
seed
=
1
)
# Destroy the buffer runtime and communication group
buffer
.
destroy
()
dist
.
barrier
()
dist
.
destroy_process_group
()
if
__name__
==
'__main__'
:
# TODO: you may modify NUMA binding for less CPU overhead
# TODO: buggy with `num_tokens=512`
parser
=
argparse
.
ArgumentParser
(
description
=
'Test low-latency EP kernels'
)
parser
.
add_argument
(
'--num-processes'
,
type
=
int
,
default
=
8
,
help
=
'Number of processes to spawn (default: 8)'
)
parser
.
add_argument
(
'--num-tokens'
,
type
=
int
,
default
=
128
,
help
=
'Number of tokens (default: 128)'
)
parser
.
add_argument
(
'--hidden'
,
type
=
int
,
default
=
7168
,
help
=
'Hidden dimension size (default: 7168)'
)
parser
.
add_argument
(
'--num-topk'
,
type
=
int
,
default
=
8
,
help
=
'Number of top-k experts (default: 8)'
)
parser
.
add_argument
(
'--num-experts'
,
type
=
int
,
default
=
256
,
help
=
'Number of experts (default: 288)'
)
parser
.
add_argument
(
'--allow-mnnvl'
,
action
=
"store_true"
,
help
=
'Allow MNNVL for communication'
)
parser
.
add_argument
(
'--disable-nvlink'
,
action
=
'store_true'
,
help
=
'Whether to disable NVLink for testing'
)
parser
.
add_argument
(
"--pressure-test"
,
action
=
'store_true'
,
help
=
'Whether to do pressure test'
)
parser
.
add_argument
(
"--shrink-test"
,
action
=
'store_true'
,
help
=
'Whether to simulate failure and test shrink mode'
)
parser
.
add_argument
(
'--use-logfmt'
,
action
=
'store_true'
,
help
=
'Whether to test LogFMT combine'
)
args
=
parser
.
parse_args
()
num_processes
=
args
.
num_processes
torch
.
multiprocessing
.
spawn
(
test_loop
,
args
=
(
num_processes
,
args
),
nprocs
=
num_processes
)
tests/utils.py
View file @
ace6e18e
...
...
@@ -57,28 +57,22 @@ def per_token_cast_to_fp8(x: torch.Tensor):
return
(
x_padded_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
aligned_n
)[:,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
per_token_cast_back
(
x
_fp8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
_fp8
.
numel
()
==
0
:
return
x
_fp8
.
to
(
torch
.
bfloat16
)
def
per_token_cast_
pg_
back
(
x
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x
.
numel
()
==
0
:
return
x
.
to
(
torch
.
bfloat16
)
assert
x
_fp8
.
dim
()
==
2
m
,
n
=
x
_fp8
.
shape
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
aligned_n
=
align_up
(
n
,
128
)
x_
fp8_
padded
=
torch
.
nn
.
functional
.
pad
(
x
_fp8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_padded
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
if
x_scales
.
dtype
==
torch
.
int
:
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
uint8
).
to
(
torch
.
int
)
<<
23
x_scales
=
x_scales
.
view
(
dtype
=
torch
.
float
)
x_fp32_padded
=
x_fp8_padded
.
to
(
torch
.
float32
).
view
(
x_fp8
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x_fp8
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_fp8_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
x_fp32_padded
=
x_padded
.
to
(
torch
.
float32
).
view
(
x
.
size
(
0
),
-
1
,
128
)
x_scales
=
x_scales
.
view
(
x
.
size
(
0
),
-
1
,
1
)
return
(
x_fp32_padded
*
x_scales
).
view
(
x_padded
.
shape
).
to
(
torch
.
bfloat16
)[:,:
n
].
contiguous
()
def
per_token_cast_back_int8
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
"""
x_int8: [m, n] int8 tensor
x_scales: [m, n] 或 [m, 1] 或 [m, n/128] 量化 scale float
return: [m, n] bf16 tensor
"""
def
per_token_cast_pc_back
(
x_int8
:
torch
.
Tensor
,
x_scales
:
torch
.
Tensor
):
if
x_int8
.
numel
()
==
0
:
return
x_int8
.
to
(
torch
.
bfloat16
)
...
...
@@ -86,12 +80,9 @@ def per_token_cast_back_int8(x_int8: torch.Tensor, x_scales: torch.Tensor):
m
,
n
=
x_int8
.
shape
aligned_n
=
align_up
(
n
,
128
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_int8_padded
=
torch
.
nn
.
functional
.
pad
(
x_int8
,
(
0
,
aligned_n
-
n
),
mode
=
'constant'
,
value
=
0
)
x_fp32_padded
=
x_int8_padded
.
to
(
torch
.
float32
).
view
(
m
,
-
1
,
1
)
x_scales
=
x_scales
.
view
(
m
,
-
1
,
1
).
to
(
torch
.
float32
)
# print(f'x_int8.shape: {x_int8.shape}, x_fp32_padded: {x_fp32_padded.shape}, x_scales: {x_scales.shape}')
x_deq
=
(
x_fp32_padded
*
x_scales
).
view
(
m
,
aligned_n
)
return
x_deq
[:,
:
n
].
to
(
torch
.
bfloat16
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment