Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99caa491
Unverified
Commit
99caa491
authored
May 16, 2024
by
Jinzhen Lin
Committed by
GitHub
May 16, 2024
Browse files
[Kernel] add bfloat16 support for gptq marlin kernel (#4788)
parent
5c342570
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
246 additions
and
73 deletions
+246
-73
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+173
-67
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
+62
-0
tests/models/test_gptq_marlin.py
tests/models/test_gptq_marlin.py
+7
-2
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+4
-4
No files found.
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
99caa491
...
@@ -20,6 +20,11 @@
...
@@ -20,6 +20,11 @@
*/
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) static_assert(\
std::is_same<scalar_t, half>::value || std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
template
<
typename
T
>
inline
std
::
string
str
(
T
x
)
{
return
std
::
to_string
(
x
);
}
...
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
...
@@ -32,7 +37,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{}
int
size_k
,
int
block_rows
)
{}
template
<
const
int
num_bits
,
// number of bits used for weights
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
// dimension (batchsize) of the threadblock
...
@@ -72,31 +78,36 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -72,31 +78,36 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
#else
#else
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using
FragA
=
Vec
<
half2
,
4
>
;
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
// quantization scales
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
// output/accumulation.
__device__
inline
void
mma
(
const
FragA
&
a_frag
,
const
FragB
&
frag_b
,
template
<
typename
scalar_t
>
FragC
&
frag_c
)
{
__device__
inline
void
mma
(
const
typename
ScalarType
<
scalar_t
>::
FragA
&
a_frag
,
const
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragC
&
frag_c
)
{
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
a
=
reinterpret_cast
<
const
uint32_t
*>
(
&
a_frag
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
const
uint32_t
*
b
=
reinterpret_cast
<
const
uint32_t
*>
(
&
frag_b
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
float
*
c
=
reinterpret_cast
<
float
*>
(
&
frag_c
);
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
if
constexpr
(
std
::
is_same
<
scalar_t
,
half
>::
value
)
{
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
if
constexpr
(
std
::
is_same
<
scalar_t
,
nv_bfloat16
>::
value
)
{
asm
volatile
(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};
\n
"
:
"=f"
(
c
[
0
]),
"=f"
(
c
[
1
]),
"=f"
(
c
[
2
]),
"=f"
(
c
[
3
])
:
"r"
(
a
[
0
]),
"r"
(
a
[
1
]),
"r"
(
a
[
2
]),
"r"
(
a
[
3
]),
"r"
(
b
[
0
]),
"r"
(
b
[
1
]),
"f"
(
c
[
0
]),
"f"
(
c
[
1
]),
"f"
(
c
[
2
]),
"f"
(
c
[
3
]));
}
else
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
// memory, directly in tensor core layout.
__device__
inline
void
ldsm4
(
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
template
<
typename
scalar_t
>
__device__
inline
void
ldsm4
(
typename
ScalarType
<
scalar_t
>::
FragA
&
frag_a
,
const
void
*
smem_ptr
)
{
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
*
a
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_a
);
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
uint32_t
smem
=
static_cast
<
uint32_t
>
(
__cvta_generic_to_shared
(
smem_ptr
));
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
asm
volatile
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];
\n
"
...
@@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
...
@@ -129,8 +140,15 @@ __device__ inline uint32_t prmt(uint32_t a) {
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// values. We mostly follow the strategy in the link below, with some small
// changes:
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
__device__
inline
FragB
dequant_4bit
(
int
q
)
{
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_4bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_4bit
<
half
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
const
int
EX
=
0x64006400
;
...
@@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
...
@@ -142,7 +160,7 @@ __device__ inline FragB dequant_4bit(int q) {
const
int
SUB
=
0x64086408
;
const
int
SUB
=
0x64086408
;
const
int
MUL
=
0x2c002c00
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd480d480
;
const
int
ADD
=
0xd480d480
;
FragB
frag_b
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
...
@@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
...
@@ -151,7 +169,41 @@ __device__ inline FragB dequant_4bit(int q) {
return
frag_b
;
return
frag_b
;
}
}
__device__
inline
FragB
dequant_8bit
(
int
q
)
{
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_4bit
<
nv_bfloat16
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC308C308
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or bf16
// Reference:
// - FP16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16: https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit
<
half
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
...
@@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {
...
@@ -161,7 +213,7 @@ __device__ inline FragB dequant_8bit(int q) {
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
static
constexpr
uint32_t
I8s_TO_F16s_MAGIC_NUM
=
0x64806480
;
FragB
frag_b
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
*
reinterpret_cast
<
const
half2
*>
(
&
I8s_TO_F16s_MAGIC_NUM
));
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
frag_b
[
1
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
...
@@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
...
@@ -169,34 +221,69 @@ __device__ inline FragB dequant_8bit(int q) {
return
frag_b
;
return
frag_b
;
}
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_8bit
<
nv_bfloat16
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
uint32_t
*
fp32_intermediates_casted
=
reinterpret_cast
<
uint32_t
*>
(
fp32_intermediates
);
static
constexpr
uint32_t
fp32_base
=
0x4B000000
;
fp32_intermediates_casted
[
0
]
=
__byte_perm
(
q
,
fp32_base
,
0x7650
);
fp32_intermediates_casted
[
1
]
=
__byte_perm
(
q
,
fp32_base
,
0x7652
);
fp32_intermediates_casted
[
2
]
=
__byte_perm
(
q
,
fp32_base
,
0x7651
);
fp32_intermediates_casted
[
3
]
=
__byte_perm
(
q
,
fp32_base
,
0x7653
);
fp32_intermediates
[
0
]
-=
8388736.
f
;
fp32_intermediates
[
1
]
-=
8388736.
f
;
fp32_intermediates
[
2
]
-=
8388736.
f
;
fp32_intermediates
[
3
]
-=
8388736.
f
;
uint32_t
*
bf16_result_ptr
=
reinterpret_cast
<
uint32_t
*>
(
&
frag_b
);
bf16_result_ptr
[
0
]
=
__byte_perm
(
fp32_intermediates_casted
[
0
],
fp32_intermediates_casted
[
1
],
0x7632
);
bf16_result_ptr
[
1
]
=
__byte_perm
(
fp32_intermediates_casted
[
2
],
fp32_intermediates_casted
[
3
],
0x7632
);
return
frag_b
;
}
// Multiply dequantized values by the corresponding quantization scale; used
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
// only for grouped quantization.
__device__
inline
void
scale
(
FragB
&
frag_b
,
FragS
&
frag_s
,
int
i
)
{
template
<
typename
scalar_t
>
half2
s
=
__half2half2
(
reinterpret_cast
<
__half
*>
(
&
frag_s
)[
i
]);
__device__
inline
void
scale
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s
,
int
i
)
{
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
scalar_t2
s
=
ScalarType
<
scalar_t
>::
num2num2
(
reinterpret_cast
<
scalar_t
*>
(
&
frag_s
)[
i
]);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s
);
}
}
// Same as above, but for act_order (each K is multiplied individually)
// Same as above, but for act_order (each K is multiplied individually)
__device__
inline
void
scale4
(
FragB
&
frag_b
,
FragS
&
frag_s_1
,
FragS
&
frag_s_2
,
template
<
typename
scalar_t
>
FragS
&
frag_s_3
,
FragS
&
frag_s_4
,
int
i
)
{
__device__
inline
void
scale4
(
typename
ScalarType
<
scalar_t
>::
FragB
&
frag_b
,
__half2
s_val_1_2
;
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_1
,
s_val_1_2
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_1
)[
i
];
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_2
,
s_val_1_2
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_2
)[
i
];
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_3
,
typename
ScalarType
<
scalar_t
>::
FragS
&
frag_s_4
,
__half2
s_val_3_4
;
int
i
)
{
s_val_3_4
.
x
=
reinterpret_cast
<
__half
*>
(
&
frag_s_3
)[
i
];
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
s_val_3_4
.
y
=
reinterpret_cast
<
__half
*>
(
&
frag_s_4
)[
i
];
scalar_t2
s_val_1_2
;
s_val_1_2
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_1
)[
i
];
s_val_1_2
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_2
)[
i
];
scalar_t2
s_val_3_4
;
s_val_3_4
.
x
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_3
)[
i
];
s_val_3_4
.
y
=
reinterpret_cast
<
scalar_t
*>
(
&
frag_s_4
)[
i
];
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
0
]
=
__hmul2
(
frag_b
[
0
],
s_val_1_2
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
frag_b
[
1
]
=
__hmul2
(
frag_b
[
1
],
s_val_3_4
);
}
}
// Given 2 floats multiply by 2 scales (halves)
// Given 2 floats multiply by 2 scales (halves)
__device__
inline
void
scale_float
(
float
*
c
,
FragS
&
s
)
{
template
<
typename
scalar_t
>
__half
*
s_ptr
=
reinterpret_cast
<
__half
*>
(
&
s
);
__device__
inline
void
scale_float
(
float
*
c
,
typename
ScalarType
<
scalar_t
>::
FragS
&
s
)
{
c
[
0
]
=
__fmul_rn
(
c
[
0
],
__half2float
(
s_ptr
[
0
]));
scalar_t
*
s_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
s
);
c
[
1
]
=
__fmul_rn
(
c
[
1
],
__half2float
(
s_ptr
[
1
]));
c
[
0
]
=
__fmul_rn
(
c
[
0
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
0
]));
c
[
1
]
=
__fmul_rn
(
c
[
1
],
ScalarType
<
scalar_t
>::
num2float
(
s_ptr
[
1
]));
}
}
// Wait until barrier reaches `count`, then lock for current threadblock.
// Wait until barrier reaches `count`, then lock for current threadblock.
...
@@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
...
@@ -287,7 +374,8 @@ __global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr,
}
}
}
}
template
<
const
int
num_bits
,
// number of bits used for weights
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
const
int
threads
,
// number of threads in a threadblock
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the threadblock
// dimension (batchsize) of the threadblock
...
@@ -323,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -323,6 +411,12 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// ensures good utilization of all SMs for many kinds of shape and GPU
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
// reductions as possible.
using
Dtype
=
ScalarType
<
scalar_t
>
;
using
scalar_t2
=
typename
ScalarType
<
scalar_t
>::
scalar_t2
;
using
FragA
=
typename
ScalarType
<
scalar_t
>::
FragA
;
using
FragB
=
typename
ScalarType
<
scalar_t
>::
FragB
;
using
FragC
=
typename
ScalarType
<
scalar_t
>::
FragC
;
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
constexpr
int
pack_factor
=
32
/
num_bits
;
constexpr
int
pack_factor
=
32
/
num_bits
;
...
@@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -691,7 +785,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
int4
*
sh_a_stage
=
sh_a
+
a_sh_stage
*
pipe
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
ldsm4
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
ldsm4
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
&
sh_a_stage
[
a_sh_rd_trans
[
k
%
b_sh_wr_iters
][
i
]]);
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
int4
*
sh_b_stage
=
sh_b
+
b_sh_stage
*
pipe
;
#pragma unroll
#pragma unroll
...
@@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -835,43 +929,43 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
int
b_quant_shift
=
b_quant
>>
8
;
frag_b0
=
dequant_4bit
(
b_quant
);
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
(
b_quant_shift
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
else
{
}
else
{
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
frag_b0
=
dequant_8bit
(
b_quant_0
);
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
(
b_quant_1
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
}
// Apply scale to frag_b0
// Apply scale to frag_b0
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
scale4
<
scalar_t
>
(
frag_b0
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
0
);
}
else
{
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
scale
<
scalar_t
>
(
frag_b0
,
frag_s
[
k
%
2
][
j
],
0
);
}
}
}
}
// Apply scale to frag_b1
// Apply scale to frag_b1
if
constexpr
(
has_act_order
)
{
if
constexpr
(
has_act_order
)
{
scale4
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
scale4
<
scalar_t
>
(
frag_b1
,
act_frag_s
[
k
%
2
][
0
][
j
],
act_frag_s
[
k
%
2
][
1
][
j
],
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
act_frag_s
[
k
%
2
][
2
][
j
],
act_frag_s
[
k
%
2
][
3
][
j
],
1
);
}
else
{
}
else
{
if
constexpr
(
group_blocks
!=
-
1
)
{
if
constexpr
(
group_blocks
!=
-
1
)
{
scale
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
scale
<
scalar_t
>
(
frag_b1
,
frag_s
[
k
%
2
][
j
],
1
);
}
}
}
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
mma
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b0
,
frag_c
[
i
][
j
][
0
]);
mma
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
mma
<
scalar_t
>
(
frag_a
[
k
%
2
][
i
],
frag_b1
,
frag_c
[
i
][
j
][
1
]);
}
}
}
}
};
};
...
@@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -979,15 +1073,15 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]
+=
__half
2float
(
reinterpret_cast
<
__half
*>
(
&
c_red
)[
j
]);
Dtype
::
num
2float
(
reinterpret_cast
<
scalar_t
*>
(
&
c_red
)[
j
]);
}
}
}
}
if
(
!
last
)
{
if
(
!
last
)
{
int4
c
;
int4
c
;
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
2
*
4
;
j
++
)
{
reinterpret_cast
<
__half
*>
(
&
c
)[
j
]
=
reinterpret_cast
<
scalar_t
*>
(
&
c
)[
j
]
=
__
float2
half
(
reinterpret_cast
<
float
*>
(
Dtype
::
float2
num
(
reinterpret_cast
<
float
*>
(
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
&
frag_c
)[
4
*
2
*
4
*
(
i
/
4
)
+
4
*
j
+
(
i
%
4
)]);
}
}
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
C
[
c_gl_wr
+
c_gl_wr_delta_o
*
(
i
/
2
)
+
c_gl_wr_delta_i
*
(
i
%
2
)]
=
...
@@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1022,7 +1116,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
// We first reorder in shared memory to guarantee the most efficient final
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
// global write patterns
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
auto
write
=
[
&
](
int
idx
,
float
c0
,
float
c1
,
FragS
&
s
)
{
half2
res
=
__halves2half2
(
__
float2
half
(
c0
),
__
float2
half
(
c1
));
scalar_t2
res
=
Dtype
::
nums2num2
(
Dtype
::
float2
num
(
c0
),
Dtype
::
float2
num
(
c1
));
// For per-column quantization we finally apply the scale here (only for
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
// 4-bit)
...
@@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1030,7 +1124,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
res
=
__hmul2
(
res
,
s
[
0
]);
res
=
__hmul2
(
res
,
s
[
0
]);
}
}
((
half
2
*
)
sh
)[
idx
]
=
res
;
((
scalar_t
2
*
)
sh
)[
idx
]
=
res
;
};
};
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
@@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1192,14 +1286,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
0
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
0
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
0
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
scale_float
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
scale_float
<
scalar_t
>
(
reinterpret_cast
<
float
*>
(
&
frag_c
[
i
][
j
][
1
][
2
]),
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
frag_s
[
j
/
2
][
2
*
(
j
%
2
)
+
1
]);
}
}
}
}
...
@@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
...
@@ -1255,10 +1349,10 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
cudaFuncSetAttribute( \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
Marlin<
scalar_t,
NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
Marlin<
scalar_t,
NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
...
@@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
...
@@ -1462,6 +1556,7 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
marlin_mm_f16i4
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
s
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
void
*
g_idx
,
void
*
perm
,
void
*
a_tmp
,
int
prob_m
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
int
prob_n
,
int
prob_k
,
void
*
workspace
,
int
num_bits
,
...
@@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
...
@@ -1731,14 +1826,25 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
" is below min_workspace_size = "
,
min_workspace_size
);
" is below min_workspace_size = "
,
min_workspace_size
);
int
dev
=
a
.
get_device
();
int
dev
=
a
.
get_device
();
gptq_marlin
::
marlin_mm_f16i4
(
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
Half
)
{
a
.
data_ptr
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
(),
b_scales
.
data_ptr
(),
gptq_marlin
::
marlin_mm_f16i4
<
half
>
(
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
(),
size_m
,
size_n
,
a
.
data_ptr
<
at
::
Half
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
Half
>
(),
b_scales
.
data_ptr
<
at
::
Half
>
(),
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
Half
>
(),
size_m
,
size_n
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
}
else
if
(
a
.
scalar_type
()
==
at
::
ScalarType
::
BFloat16
)
{
gptq_marlin
::
marlin_mm_f16i4
<
nv_bfloat16
>
(
a
.
data_ptr
<
at
::
BFloat16
>
(),
b_q_weight
.
data_ptr
(),
c
.
data_ptr
<
at
::
BFloat16
>
(),
b_scales
.
data_ptr
<
at
::
BFloat16
>
(),
g_idx
.
data_ptr
(),
perm
.
data_ptr
(),
a_tmp
.
data_ptr
<
at
::
BFloat16
>
(),
size_m
,
size_n
,
size_k
,
workspace
.
data_ptr
(),
num_bits
,
has_act_order
,
is_k_full
,
num_groups
,
group_size
,
dev
,
at
::
cuda
::
getCurrentCUDAStream
(
dev
),
thread_k
,
thread_n
,
sms
,
gptq_marlin
::
max_par
);
}
else
{
TORCH_CHECK
(
false
,
"gpt_marlin_gemm only supports bfloat16 and float16"
);
}
return
c
;
return
c
;
}
}
#endif
#endif
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin_dtypes.cuh
0 → 100644
View file @
99caa491
#ifndef _data_types_cuh
#define _data_types_cuh
#include "gptq_marlin.cuh"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace
gptq_marlin
{
template
<
typename
scalar_t
>
class
ScalarType
{
};
template
<
>
class
ScalarType
<
half
>
{
public:
using
scalar_t
=
half
;
using
scalar_t2
=
half2
;
// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using
FragA
=
Vec
<
half2
,
4
>
;
using
FragB
=
Vec
<
half2
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
half2
,
1
>
;
static
__device__
float
inline
num2float
(
const
half
x
)
{
return
__half2float
(
x
);
}
static
__device__
half2
inline
num2num2
(
const
half
x
)
{
return
__half2half2
(
x
);
}
static
__device__
half2
inline
nums2num2
(
const
half
x1
,
const
half
x2
)
{
return
__halves2half2
(
x1
,
x2
);
}
static
__host__
__device__
half
inline
float2num
(
const
float
x
)
{
return
__float2half
(
x
);
}
};
template
<
>
class
ScalarType
<
nv_bfloat16
>
{
public:
using
scalar_t
=
nv_bfloat16
;
using
scalar_t2
=
nv_bfloat162
;
using
FragA
=
Vec
<
nv_bfloat162
,
4
>
;
using
FragB
=
Vec
<
nv_bfloat162
,
2
>
;
using
FragC
=
Vec
<
float
,
4
>
;
using
FragS
=
Vec
<
nv_bfloat162
,
1
>
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static
__device__
float
inline
num2float
(
const
nv_bfloat16
x
)
{
return
__bfloat162float
(
x
);
}
static
__device__
nv_bfloat162
inline
num2num2
(
const
nv_bfloat16
x
)
{
return
__bfloat162bfloat162
(
x
);
}
static
__device__
nv_bfloat162
inline
nums2num2
(
const
nv_bfloat16
x1
,
const
nv_bfloat16
x2
)
{
return
__halves2bfloat162
(
x1
,
x2
);
}
static
__host__
__device__
nv_bfloat16
inline
float2num
(
const
float
x
)
{
return
__float2bfloat16
(
x
);
}
#endif
};
}
#endif
tests/models/test_gptq_marlin.py
View file @
99caa491
...
@@ -14,6 +14,7 @@ import pytest
...
@@ -14,6 +14,7 @@ import pytest
import
torch
import
torch
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.rotary_embedding
import
_ROPE_DICT
from
.utils
import
check_logprobs_close
from
.utils
import
check_logprobs_close
...
@@ -52,7 +53,7 @@ MODELS = [
...
@@ -52,7 +53,7 @@ MODELS = [
@
pytest
.
mark
.
skipif
(
gptq_marlin_not_supported
,
@
pytest
.
mark
.
skipif
(
gptq_marlin_not_supported
,
reason
=
"gptq_marlin is not supported on this GPU type."
)
reason
=
"gptq_marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
,
"bfloat16"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
5
])
def
test_models
(
def
test_models
(
...
@@ -76,11 +77,15 @@ def test_models(
...
@@ -76,11 +77,15 @@ def test_models(
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
gptq_marlin_outputs
=
gptq_marlin_model
.
generate_greedy_logprobs
(
example_prompts
[:
-
1
],
max_tokens
,
num_logprobs
)
example_prompts
[:
-
1
],
max_tokens
,
num_logprobs
)
del
gptq_marlin_model
del
gptq_marlin_model
_ROPE_DICT
.
clear
()
# clear rope cache to avoid rope dtype error
# Run gptq.
# Run gptq.
# The naive gptq kernel doesn't support bf16 yet.
# Here we always compare fp16/bf16 gpt marlin kernel
# to fp16 gptq kernel.
gptq_model
=
vllm_runner
(
model_name
=
model_name
,
gptq_model
=
vllm_runner
(
model_name
=
model_name
,
revision
=
revision
,
revision
=
revision
,
dtype
=
dtype
,
dtype
=
"half"
,
quantization
=
"gptq"
,
quantization
=
"gptq"
,
max_model_len
=
MAX_MODEL_LEN
,
max_model_len
=
MAX_MODEL_LEN
,
tensor_parallel_size
=
1
)
tensor_parallel_size
=
1
)
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
99caa491
...
@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
...
@@ -99,7 +99,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@
classmethod
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -186,9 +186,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size
=
input_size
group_size
=
input_size
# Validate dtype
# Validate dtype
if
params_dtype
!=
torch
.
float16
:
if
params_dtype
not
in
[
torch
.
float16
,
torch
.
bfloat16
]
:
raise
ValueError
(
raise
ValueError
(
f
"The params dtype must be float16 "
f
"The params dtype must be
float16, but got
{
params_dtype
}
"
)
f
"or b
float16, but got
{
params_dtype
}
"
)
# Validate output_size_per_partition
# Validate output_size_per_partition
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment