Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
8896615a
Commit
8896615a
authored
Sep 29, 2025
by
zhushuang
Browse files
feat: add AWQ dequantize in moore gpu, with test pass
parent
3959c943
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
210 additions
and
1 deletion
+210
-1
src/infiniop/devices/moore/moore_kernel_common.h
src/infiniop/devices/moore/moore_kernel_common.h
+4
-1
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h
...iniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h
+56
-0
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.h
...finiop/ops/dequantize_awq/moore/dequantize_w42f16_moore.h
+8
-0
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu
...iniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu
+127
-0
src/infiniop/ops/dequantize_awq/operator.cc
src/infiniop/ops/dequantize_awq/operator.cc
+15
-0
No files found.
src/infiniop/devices/moore/moore_kernel_common.h
View file @
8896615a
...
@@ -37,9 +37,12 @@ exp_(const float val) {
...
@@ -37,9 +37,12 @@ exp_(const float val) {
return
expf
(
val
);
return
expf
(
val
);
}
}
// Computes exp for long double on Moore GPU,
// casts to double to resolve ambiguous exp call,
// due to conflicting double/float definitions in MUSA math libraries.
__forceinline__
__device__
long
double
__forceinline__
__device__
long
double
exp_
(
const
long
double
val
)
{
exp_
(
const
long
double
val
)
{
return
exp
(
val
);
return
static_cast
<
long
double
>
(
exp
(
static_cast
<
double
>
(
val
)
))
;
}
}
__forceinline__
__device__
double
__forceinline__
__device__
double
...
...
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_kernel.h
0 → 100644
View file @
8896615a
#pragma once

#include <musa_fp16.h> // required for the __half and __half2 types

/**
 * @brief Dequantizes a uint32_t packing eight 4-bit integers into eight
 *        half-precision floats.
 *
 * Generic CUDA C++ replacement for the original PTX assembly version, so it
 * can run on GPUs that do not support advanced PTX instructions such as
 * lop3.b32. The output order matches the PTX interleaved packing:
 * v0, v4, v1, v5, v2, v6, v3, v7 (after the signed adjustment).
 *
 * @param source 32-bit unsigned integer packing eight 4-bit values, laid out
 *               as [v7, v6, v5, v4, v3, v2, v1, v0] with v0 in the
 *               least-significant nibble.
 * @return A uint4 holding the eight dequantized half values.
 */
__device__ __forceinline__ uint4 dequantize_s4_to_fp16x2(uint32_t const &source) {
    uint4 packed;
    __half2 *out = reinterpret_cast<__half2 *>(&packed);

    // Signed 4-bit (s4): subtracting 8 maps the raw [0, 15] range onto [-8, 7].
    const __half bias = __half(8);

    // Each iteration produces one __half2. The PTX-compatible interleaving
    // pairs nibble i (low half) with nibble i + 4 (high half):
    //   out[0] = (v0, v4), out[1] = (v1, v5), out[2] = (v2, v6), out[3] = (v3, v7)
    for (int i = 0; i < 4; ++i) {
        const unsigned int lo_nibble = (source >> (4 * i)) & 0x0F;
        const unsigned int hi_nibble = (source >> (4 * i + 16)) & 0x0F;
        // __halves2half2(low, high): the first argument becomes the low half.
        out[i] = __halves2half2(__half(lo_nibble) - bias, __half(hi_nibble) - bias);
    }

    return packed;
}
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.h
0 → 100644
View file @
8896615a
#ifndef __DEQUANTIZE_AWQ_MOORE_H__
#define __DEQUANTIZE_AWQ_MOORE_H__

#include "../dequantize_awq.h"

// Instantiates the dequantize-AWQ Descriptor declaration for the Moore (MUSA)
// backend; the DESCRIPTOR macro is defined in ../dequantize_awq.h.
DESCRIPTOR(moore)

#endif // __DEQUANTIZE_AWQ_MOORE_H__
src/infiniop/ops/dequantize_awq/moore/dequantize_w42f16_moore.mu
0 → 100644
View file @
8896615a
#include "../../../devices/moore/moore_handle.h"
#include "../../../devices/moore/moore_kernel_common.h"
#include "dequantize_w42f16_moore.h"
#include "dequantize_w42f16_kernel.h"
#include "../dequantize_awq.h"
#include <musa_fp16.h>
// Dequantizes AWQ 4-bit packed weights B into fp16 output C.
// Each thread loads one packed int (8 nibbles) and writes 8 half values.
// B: packed 4-bit quantized weights (8 values per int)
// scaling_factors / zeros: per-group quantization parameters
// C: fp16 output, 8x wider than B in the packed dimension
// G: quantization group size (rows per group)
// NOTE(review): assumes a 2-D launch where blockDim.x * gridDim.x spans the
// packed column dimension and y spans rows — confirm against the host-side
// launch in calculate().
__global__ void __launch_bounds__(64)
dequantize_weights(int *__restrict__ B, half *__restrict__ scaling_factors,
                   int *__restrict__ zeros, half *__restrict__ C, int G) {
    // static constexpr uint32_t ZERO = 0x0;
    // NOTE(review): despite the name, B_shared is a per-thread LOCAL array
    // (no __shared__ qualifier) of 32 * (128 + 8) halves, of which only the
    // first 8 elements are written/read below — confirm whether this should
    // be __shared__ or simply a half[8] scratch buffer.
    half B_shared[32 * (128 + 8)];
    half *B_shared_ptr2 = B_shared;
    int N = blockDim.x * gridDim.x; // 2
    int col = (blockIdx.x * blockDim.x + threadIdx.x);
    int row = (blockIdx.y * blockDim.y + threadIdx.y);
    // Each packed int expands to 8 halves, hence the factor 8 in C's offsets.
    int index1 = 8 * col + 8 * row * N;
    half *C_ptr2 = C + index1;
    // Position of this thread's packed weight int in B.
    int index2 = col + row * N;
    int *B_ptr2 = B + index2;
    // row / G selects the quantization group for zeros and scales.
    int index3 = col + (int)(row / G) * N;
    int *zeros_ptr2 = zeros + index3;
    int index4 = 8 * col + (int)(row / G) * N * 8;
    half *scaling_factors_ptr2 = scaling_factors + index4;
    // Zeros are packed 4-bit just like the weights, so they go through the
    // same unpack routine (its internal -8 bias cancels in the subtraction
    // below, since both operands are biased identically).
    uint32_t zeros_loaded = *(uint32_t *)(zeros_ptr2);
    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
    uint4 B_loaded_scale = *(uint4 *)(scaling_factors_ptr2);
    uint32_t B_loaded = *(uint32_t *)B_ptr2;
    uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
    // Reinterpret uint4 components as __half2
    __half2 *B_loaded_fp16_h2 = reinterpret_cast<__half2 *>(&B_loaded_fp16);
    __half2 *B_loaded_zero_h2 = reinterpret_cast<__half2 *>(&B_loaded_zero);
    __half2 *B_loaded_scale_h2 = reinterpret_cast<__half2 *>(&B_loaded_scale);
    // Replace PTX sub.f16x2 with __hsub2 for each component: w - zero
    B_loaded_fp16_h2[0] = __hsub2(B_loaded_fp16_h2[0], B_loaded_zero_h2[0]);
    B_loaded_fp16_h2[1] = __hsub2(B_loaded_fp16_h2[1], B_loaded_zero_h2[1]);
    B_loaded_fp16_h2[2] = __hsub2(B_loaded_fp16_h2[2], B_loaded_zero_h2[2]);
    B_loaded_fp16_h2[3] = __hsub2(B_loaded_fp16_h2[3], B_loaded_zero_h2[3]);
    // Replace PTX fma.rn.f16x2 with __hfma2 for each component:
    // (w - zero) * scale + 0 (the zero addend makes this a plain multiply).
    B_loaded_fp16_h2[0] = __hfma2(B_loaded_fp16_h2[0], B_loaded_scale_h2[0], __float2half2_rn(0.0f));
    B_loaded_fp16_h2[1] = __hfma2(B_loaded_fp16_h2[1], B_loaded_scale_h2[1], __float2half2_rn(0.0f));
    B_loaded_fp16_h2[2] = __hfma2(B_loaded_fp16_h2[2], B_loaded_scale_h2[2], __float2half2_rn(0.0f));
    B_loaded_fp16_h2[3] = __hfma2(B_loaded_fp16_h2[3], B_loaded_scale_h2[3], __float2half2_rn(0.0f));
    // Store back to shared memory
    *(uint4 *)B_shared_ptr2 = B_loaded_fp16;
    // Copy the 8 dequantized halves to the output.
    for (int i = 0; i < 8; ++i) {
        *(C_ptr2 + i) = B_shared[i];
    }
}
namespace op::dequantize_awq::moore {

// Backend-specific state owned by the descriptor (Moore device handle internals).
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Releases the backend-specific state owned by this descriptor.
Descriptor::~Descriptor() {
    delete _opaque;
}

// Creates a dequantize-AWQ descriptor for the Moore backend.
// Validates/collects tensor metadata via DequantizeAWQInfo::create and
// stores it, together with the device handle internals, in *desc_ptr.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t qweight_desc,
    infiniopTensorDescriptor_t scales_desc,
    infiniopTensorDescriptor_t zeros_desc) {
    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
    auto result = DequantizeAWQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc);
    // NOTE(review): result's status is consumed via take() without being
    // checked first — confirm whether a CHECK_RESULT(result)-style guard is
    // needed here to match other backends.
    *desc_ptr = new Descriptor(
        0, // workspace size: this op requires none
        new Opaque{handle->internal()},
        result.take(),
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Launches the dequantization kernel on the given MUSA stream.
// workspace / workspace_size are accepted for interface uniformity but are
// not used by this implementation.
infiniStatus_t
Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *qweight,
    const void *scales,
    const void *zeros,
    void *stream) const {
    int in_features = _info.in_features();
    int out_features = _info.out_features();
    int group_size = in_features / _info.num_groups();
    // ==================== default configuration, fixed at 8 ====================
    constexpr int BLOCK_X = 8;
    constexpr int BLOCK_Y = 8;
    // Ceil-division so partial tiles at the edges are still covered.
    int x_blocks = (out_features + BLOCK_X - 1) / BLOCK_X;
    int y_blocks = (in_features + BLOCK_Y - 1) / BLOCK_Y;
    dim3 num_blocks(x_blocks, y_blocks);
    dim3 threads_per_block(BLOCK_X, BLOCK_Y);
    // =====================================================
    // const_cast is needed because the kernel takes non-const pointers even
    // though scales/zeros/qweight are only read.
    half *out_ = reinterpret_cast<half *>(out);
    int *qweight_ = const_cast<int *>(reinterpret_cast<const int *>(qweight));
    half *scales_ = const_cast<half *>(reinterpret_cast<const half *>(scales));
    int *zeros_ = const_cast<int *>(reinterpret_cast<const int *>(zeros));
    // NOTE(review): no musaGetLastError()-style check after the launch —
    // launch-configuration errors will surface only at a later sync; confirm
    // whether the project convention requires a post-launch check here.
    dequantize_weights<<<num_blocks, threads_per_block, 0, reinterpret_cast<musaStream_t>(stream)>>>(
        qweight_, scales_, zeros_, out_, group_size);
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::dequantize_awq::moore
src/infiniop/ops/dequantize_awq/operator.cc
View file @
8896615a
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
#include "nvidia/dequantize_w42f16_nvidia.cuh"
#include "nvidia/dequantize_w42f16_nvidia.cuh"
#endif
#endif
#ifdef ENABLE_MOORE_API
#include "moore/dequantize_w42f16_moore.h"
#endif
__C
infiniStatus_t
infiniopCreateDequantizeAWQDescriptor
(
__C
infiniStatus_t
infiniopCreateDequantizeAWQDescriptor
(
infiniopHandle_t
handle
,
infiniopHandle_t
handle
,
...
@@ -27,6 +30,9 @@ __C infiniStatus_t infiniopCreateDequantizeAWQDescriptor(
...
@@ -27,6 +30,9 @@ __C infiniStatus_t infiniopCreateDequantizeAWQDescriptor(
switch
(
handle
->
device
)
{
switch
(
handle
->
device
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
CREATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
CREATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetDequantizeAWQWorkspaceSize(infiniopDequantizeAWQDe
...
@@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetDequantizeAWQWorkspaceSize(infiniopDequantizeAWQDe
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
GET
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
GET
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -70,6 +79,9 @@ __C infiniStatus_t infiniopDequantizeAWQ(
...
@@ -70,6 +79,9 @@ __C infiniStatus_t infiniopDequantizeAWQ(
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
CALCULATE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
CALCULATE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
@@ -89,6 +101,9 @@ infiniopDestroyDequantizeAWQDescriptor(infiniopDequantizeAWQDescriptor_t desc) {
...
@@ -89,6 +101,9 @@ infiniopDestroyDequantizeAWQDescriptor(infiniopDequantizeAWQDescriptor_t desc) {
switch
(
desc
->
device_type
)
{
switch
(
desc
->
device_type
)
{
#ifdef ENABLE_NVIDIA_API
#ifdef ENABLE_NVIDIA_API
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
DELETE
(
INFINI_DEVICE_NVIDIA
,
nvidia
);
#endif
#ifdef ENABLE_MOORE_API
DELETE
(
INFINI_DEVICE_MOORE
,
moore
);
#endif
#endif
default:
default:
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
return
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment