Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
DeepEP
Commits
cf4514dc
Commit
cf4514dc
authored
Feb 04, 2026
by
lishen
Browse files
fp8量化细节调整
parent
44ec8bed
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
17 deletions
+16
-17
csrc/deep_ep.cu
csrc/deep_ep.cu
+3
-3
csrc/kernels/internode_ll.cu
csrc/kernels/internode_ll.cu
+12
-12
csrc/kernels/utils.cuh
csrc/kernels/utils.cuh
+1
-2
No files found.
csrc/deep_ep.cu
View file @
cf4514dc
...
@@ -1330,9 +1330,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
...
@@ -1330,9 +1330,9 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
auto
packed_recv_x_dtype
=
torch
::
kBFloat16
;
switch
(
quant_type
)
{
switch
(
quant_type
)
{
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
1
:
packed_recv_x_dtype
=
torch
::
kInt8
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
uz
;
break
;
case
2
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
uz
;
break
;
case
3
:
packed_recv_x_dtype
=
torch
::
kFloat8_e4m3fn
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2
fnuz
;
break
;
case
4
:
packed_recv_x_dtype
=
torch
::
kFloat8_e5m2
;
break
;
}
}
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
auto
packed_recv_x
=
torch
::
empty
({
num_local_experts
,
num_ranks
*
num_max_dispatch_tokens_per_rank
,
hidden
},
x
.
options
().
dtype
(
packed_recv_x_dtype
));
...
...
csrc/kernels/internode_ll.cu
View file @
cf4514dc
...
@@ -152,10 +152,10 @@ __forceinline__ __device__ void pack_quantized_values(
...
@@ -152,10 +152,10 @@ __forceinline__ __device__ void pack_quantized_values(
if
constexpr
(
kQuantType
==
4
)
{
if
constexpr
(
kQuantType
==
4
)
{
// FP8 E5M2
// FP8 E5M2
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E5M2
_FNUZ
);
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E5M2
);
}
else
{
}
else
{
// FP8 E4M3 或 UE8M0
// FP8 E4M3 或 UE8M0
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3
_FNUZ
);
fp8x2_ptr
[
j
/
2
]
=
__hip_cvt_float2_to_fp8x2
(
fp32x2
,
__HIP_SATFINITE
,
__HIP_E4M3
);
}
}
}
}
}
}
...
@@ -179,9 +179,9 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -179,9 +179,9 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
enum
class
QuantType
{
enum
class
QuantType
{
None
=
0
,
// 不进行量化
None
=
0
,
// 不进行量化
Int8
=
1
,
// 采用 Int8 量化
Int8
=
1
,
// 采用 Int8 量化
FP8_E4M3
=
2
,
// 采用 FP8 量化 __HIP_E4M3
_FNUZ
FP8_E4M3
=
2
,
// 采用 FP8 量化 __HIP_E4M3
FP8_UE8M0
=
3
,
// 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_UE8M0
=
3
,
// 采用 FP8 量化 DeepseekV3.1的 UE8M0
FP8_E5M2
=
4
// 采用 FP8 量化 __HIP_E5M2
_FNUZ
FP8_E5M2
=
4
// 采用 FP8 量化 __HIP_E5M2
};
};
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
const
auto
sm_id
=
static_cast
<
int
>
(
blockIdx
.
x
);
...
@@ -247,7 +247,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -247,7 +247,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__shared__
float
channel_amaxf
[
kNumScales
];
__shared__
float
channel_amaxf
[
kNumScales
];
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
(
thread_id
<
kNumScales
)
{
if
(
thread_id
<
kNumScales
)
{
channel_amaxf
[
thread_id
]
=
kFP8Margin
;
channel_amaxf
[
thread_id
]
=
0.0
;
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -262,7 +262,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -262,7 +262,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// Calculate local amax
// Calculate local amax
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
auto
bf16_values
=
reinterpret_cast
<
hip_bfloat16
*>
(
&
int4_value
);
float
fp32_values
[
kNumElemsPerRead
];
float
fp32_values
[
kNumElemsPerRead
];
float
amax
=
kFP8Margin
,
scale
,
scale_inv
;
float
amax
=
0.0
,
scale
,
scale_inv
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
for
(
int
j
=
0
;
j
<
kNumElemsPerRead
;
++
j
)
{
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
fp32_values
[
j
]
=
static_cast
<
float
>
(
bf16_values
[
j
]);
...
@@ -294,7 +294,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -294,7 +294,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
__syncthreads
();
__syncthreads
();
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
if
constexpr
(
kUseQuant8Bit
&&
kQuantGroupSize
==
0
)
{
float
amax_per_token
=
kFP8Margin
;
float
amax_per_token
=
0.0
;
// 并行规约,计算每个token的amax
// 并行规约,计算每个token的amax
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
for
(
int
s
=
0
;
s
<
kNumScales
;
s
+=
kWarpSize
)
{
int
src_idx
=
s
+
lane_id
;
int
src_idx
=
s
+
lane_id
;
...
@@ -310,7 +310,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -310,7 +310,7 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
// 根据最大值计算scale
// 根据最大值计算scale
float
scale
,
scale_inv
;
float
scale
,
scale_inv
;
calculate_quant8bit_scales
<
kQuantType
>
(
amax_per_token
,
scale
,
scale_inv
);
calculate_quant8bit_scales
<
kQuantType
>
(
amax_per_token
,
scale
,
scale_inv
,
fp8_round_scale
);
if
(
thread_id
==
0
)
{
if
(
thread_id
==
0
)
{
rdma_x_scales
[
0
]
=
scale_inv
;
rdma_x_scales
[
0
]
=
scale_inv
;
}
}
...
@@ -344,8 +344,8 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
...
@@ -344,8 +344,8 @@ __global__ __launch_bounds__(16 * kWarpSize, 1) void
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
uint64_t
p2p_ptr
=
internode
::
shmem_get_p2p_ptr
((
void
*
)
dst_ptr
,
rank
,
dst_rank
);
if
(
p2p_ptr
==
0
)
{
// RDMA
if
(
p2p_ptr
==
0
)
{
// RDMA
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
internode_ll_putmem_nbi
((
void
*
)
dst_ptr
,
(
void
*
)
src_ptr
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_ranks
,
dst_rank
,
dst_expert_local_idx
,
num_bytes_per_msg
);
num_bytes_per_msg
);
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
}
else
{
// 本地 GPU 和 同一计算节点的 其他 GPU 地址
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
// NOTES: only 2 load iterations for 7K hidden with 8 unrolls
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
const
auto
*
src_int4_ptr
=
reinterpret_cast
<
const
int4
*>
(
src_ptr
);
...
@@ -571,9 +571,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
...
@@ -571,9 +571,9 @@ void dispatch(void* packed_recv_x, void* packed_recv_x_scales,
/*量化类型枚举
/*量化类型枚举
0 -> None 不量化,保持原始精度
0 -> None 不量化,保持原始精度
1 -> Int8 使用 INT8 对称量化
1 -> Int8 使用 INT8 对称量化
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3
_FNUZ
)
2 -> FP8_E4M3 使用 FP8 E4M3 格式 (__HIP_E4M3)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
3 -> FP8_UE8M0 使用 DeepSeekV3.1 提出的 UE8M0 格式 (仅支持round_scale=True)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2
_FNUZ
)
4 -> FP8_E5M2 使用 FP8 E5M2 格式 (__HIP_E5M2)
*/
*/
#define DISPATCH_LAUNCH_CASE(hidden) \
#define DISPATCH_LAUNCH_CASE(hidden) \
...
...
csrc/kernels/utils.cuh
View file @
cf4514dc
...
@@ -342,8 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
...
@@ -342,8 +342,7 @@ __device__ __forceinline__ dtype_t broadcast(dtype_t &ptr, int src_lane_idx) {
}
}
// 设置不同的量化方式的最大值与相反数
// 设置不同的量化方式的最大值与相反数
constexpr
float
kFP8Margin
=
0.0
;
constexpr
float
kFinfoAmaxE4M3
=
448.0
f
;
constexpr
float
kFinfoAmaxE4M3
=
240.0
f
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxInvE4M3
=
1.0
f
/
kFinfoAmaxE4M3
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxE5M2
=
57344.0
f
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
constexpr
float
kFinfoAmaxInvE5M2
=
1.0
f
/
kFinfoAmaxE5M2
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment