Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6aa33cb2
Unverified
Commit
6aa33cb2
authored
Aug 12, 2024
by
Lucas Wilkinson
Committed by
GitHub
Aug 12, 2024
Browse files
[Misc] Use scalar type to dispatch to different `gptq_marlin` kernels (#7323)
parent
1137f343
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
334 additions
and
220 deletions
+334
-220
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+166
-31
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+168
-189
No files found.
csrc/core/scalar_type.hpp
View file @
6aa33cb2
...
...
@@ -20,7 +20,7 @@ namespace vllm {
//
class
ScalarType
{
public:
enum
NanRepr
:
int
64
_t
{
enum
NanRepr
:
u
int
8
_t
{
NAN_NONE
=
0
,
// nans are not supported
NAN_IEEE_754
=
1
,
// nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN
=
2
,
// nans are: exp all 1s, mantissa all 1s
...
...
@@ -28,33 +28,33 @@ class ScalarType {
NAN_REPR_ID_MAX
};
constexpr
ScalarType
(
bool
signed_
,
int
64
_t
exponent
,
int
64
_t
mantissa
,
int
64
_t
bias
,
bool
finite_values_only
=
false
,
constexpr
ScalarType
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
signed_
,
int
32
_t
bias
,
bool
finite_values_only
=
false
,
NanRepr
nan_repr
=
NAN_IEEE_754
)
:
exponent
(
exponent
),
mantissa
(
mantissa
),
bias
(
bias
),
signed_
(
signed_
),
bias
(
bias
),
finite_values_only
(
finite_values_only
),
nan_repr
(
nan_repr
){};
static
constexpr
ScalarType
int_
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
true
,
0
,
size_bits
-
1
,
bias
);
static
constexpr
ScalarType
int_
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
-
1
,
true
,
bias
);
}
static
constexpr
ScalarType
uint
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
false
,
0
,
size_bits
,
bias
);
static
constexpr
ScalarType
uint
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
,
false
,
bias
);
}
// IEEE 754 compliant floating point type
static
constexpr
ScalarType
float_IEEE754
(
int
64
_t
exponent
,
int
64
_t
mantissa
)
{
static
constexpr
ScalarType
float_IEEE754
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
)
{
TORCH_CHECK
(
mantissa
>
0
&&
exponent
>
0
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
false
,
NAN_IEEE_754
);
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
false
,
NAN_IEEE_754
);
}
// IEEE 754 non-compliant floating point type
static
constexpr
ScalarType
float_
(
int
64
_t
exponent
,
int
64
_t
mantissa
,
static
constexpr
ScalarType
float_
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
finite_values_only
,
NanRepr
nan_repr
)
{
TORCH_CHECK
(
nan_repr
<
NAN_REPR_ID_MAX
,
"Invalid NanRepr"
);
...
...
@@ -62,36 +62,121 @@ class ScalarType {
TORCH_CHECK
(
nan_repr
!=
NAN_IEEE_754
,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
finite_values_only
,
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
finite_values_only
,
nan_repr
);
}
int
64
_t
const
exponent
;
// size of the exponent field (0 for integer types)
int
64
_t
const
mantissa
;
// size of the mantissa field (size of the integer
u
int
8
_t
const
exponent
;
// size of the exponent field (0 for integer types)
u
int
8
_t
const
mantissa
;
// size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
bool
const
signed_
;
// flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool
const
finite_values_only
;
// i.e. no +/-inf if true
NanRepr
const
nan_repr
;
// how NaNs are represented
// (not applicable for integer types)
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
bool
is_signed
()
const
{
return
signed_
;
}
bool
is_integer
()
const
{
return
exponent
==
0
;
}
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
bool
is_ieee_754
()
const
{
using
Id
=
int64_t
;
private:
// Field size in id
template
<
typename
T_
>
static
constexpr
size_t
member_id_field_width
()
{
using
T
=
std
::
decay_t
<
T_
>
;
return
std
::
is_same_v
<
T
,
bool
>
?
1
:
sizeof
(
T
)
*
8
;
}
template
<
typename
Fn
,
typename
Init
,
typename
Member
,
typename
...
Rest
>
static
constexpr
auto
reduce_members_helper
(
Fn
f
,
Init
val
,
Member
member
,
Rest
...
rest
)
{
auto
new_val
=
f
(
val
,
member
);
if
constexpr
(
sizeof
...(
rest
)
>
0
)
{
return
reduce_members_helper
(
f
,
new_val
,
rest
...);
}
else
{
return
new_val
;
};
}
template
<
typename
Fn
,
typename
Init
>
constexpr
auto
reduce_members
(
Fn
f
,
Init
init
)
const
{
// Should be in constructor order for `from_id`
return
reduce_members_helper
(
f
,
init
,
exponent
,
mantissa
,
signed_
,
bias
,
finite_values_only
,
nan_repr
);
};
template
<
typename
Fn
,
typename
Init
>
static
constexpr
auto
reduce_member_types
(
Fn
f
,
Init
init
)
{
constexpr
auto
dummy_type
=
ScalarType
(
0
,
0
,
false
,
0
,
false
,
NAN_NONE
);
return
dummy_type
.
reduce_members
(
f
,
init
);
};
static
constexpr
auto
id_size_bits
()
{
return
reduce_member_types
(
[](
int
acc
,
auto
member
)
->
int
{
return
acc
+
member_id_field_width
<
decltype
(
member
)
>
();
},
0
);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr
Id
id
()
const
{
static_assert
(
id_size_bits
()
<=
sizeof
(
Id
)
*
8
,
"ScalarType id is too large to be stored"
);
auto
or_and_advance
=
[](
std
::
pair
<
Id
,
uint32_t
>
result
,
auto
member
)
->
std
::
pair
<
Id
,
uint32_t
>
{
auto
[
id
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
decltype
(
member
)
>
();
return
{
id
|
(
int64_t
(
member
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
))
<<
bit_offset
,
bit_offset
+
bits
};
};
return
reduce_members
(
or_and_advance
,
std
::
pair
<
Id
,
uint32_t
>
{}).
first
;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static
constexpr
ScalarType
from_id
(
Id
id
)
{
auto
extract_and_advance
=
[
id
](
auto
result
,
auto
member
)
{
using
T
=
decltype
(
member
);
auto
[
tuple
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
T
>
();
auto
extracted_val
=
static_cast
<
T
>
((
int64_t
(
id
)
>>
bit_offset
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
));
auto
new_tuple
=
std
::
tuple_cat
(
tuple
,
std
::
make_tuple
(
extracted_val
));
return
std
::
pair
<
decltype
(
new_tuple
),
int
>
{
new_tuple
,
bit_offset
+
bits
};
};
auto
[
tuple_args
,
_
]
=
reduce_member_types
(
extract_and_advance
,
std
::
pair
<
std
::
tuple
<>
,
int
>
{});
return
std
::
apply
([](
auto
...
args
)
{
return
ScalarType
(
args
...);
},
tuple_args
);
}
constexpr
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
constexpr
bool
is_signed
()
const
{
return
signed_
;
}
constexpr
bool
is_integer
()
const
{
return
exponent
==
0
;
}
constexpr
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
constexpr
bool
is_ieee_754
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
&&
nan_repr
==
NAN_IEEE_754
;
}
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
bool
has_infs
()
const
{
constexpr
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
constexpr
bool
has_infs
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
;
}
bool
has_bias
()
const
{
return
bias
!=
0
;
}
constexpr
bool
has_bias
()
const
{
return
bias
!=
0
;
}
private:
double
_floating_point_max
()
const
{
...
...
@@ -131,7 +216,7 @@ class ScalarType {
return
*
reinterpret_cast
<
double
*>
(
&
double_raw
);
}
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
if
(
is_floating_point
())
{
return
{
_floating_point_max
()};
}
else
{
...
...
@@ -141,7 +226,7 @@ class ScalarType {
}
}
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
if
(
is_floating_point
())
{
TORCH_CHECK
(
is_signed
(),
"We currently assume all floating point types are signed"
);
...
...
@@ -168,7 +253,7 @@ class ScalarType {
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
max
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_max
());
...
...
@@ -176,7 +261,7 @@ class ScalarType {
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
min
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_min
());
...
...
@@ -215,7 +300,7 @@ class ScalarType {
}
}
bool
operator
==
(
ScalarType
const
&
other
)
const
{
constexpr
bool
operator
==
(
ScalarType
const
&
other
)
const
{
return
mantissa
==
other
.
mantissa
&&
exponent
==
other
.
exponent
&&
bias
==
other
.
bias
&&
signed_
==
other
.
signed_
&&
finite_values_only
==
other
.
finite_values_only
&&
...
...
@@ -240,23 +325,59 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
using
Self
=
ScalarTypeTorch
;
using
SelfPtr
=
c10
::
intrusive_ptr
<
Self
>
;
static
void
check_size_bits
(
int64_t
size_bits
,
bool
signed_
)
{
TORCH_CHECK
(
size_bits
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"size_bits bit width is too large to be represented"
);
}
static
void
check_bias
(
int64_t
bias
)
{
using
Bias
=
decltype
(
std
::
declval
<
Self
>
().
bias
);
TORCH_CHECK
(
bias
<=
std
::
numeric_limits
<
Bias
>::
max
()
&&
bias
>=
std
::
numeric_limits
<
Bias
>::
min
(),
"bias too large or small to be represented"
);
}
static
void
check_exponent
(
int64_t
exponent
)
{
TORCH_CHECK
(
exponent
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
exponent
)
>::
max
(),
"exponent bit width is too large to be represented"
);
}
static
void
check_mantissa
(
int64_t
mantissa
)
{
TORCH_CHECK
(
mantissa
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"mantissa bit width is too large to be represented"
);
}
static
SelfPtr
int_
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
int_
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
uint
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
uint
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
float_IEEE754
(
int64_t
exponent
,
int64_t
mantissa
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_IEEE754
(
exponent
,
mantissa
));
}
static
SelfPtr
float_
(
int64_t
exponent
,
int64_t
mantissa
,
bool
finite_values_only
,
int64_t
nan_repr
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_
(
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
}
...
...
@@ -264,7 +385,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
template
<
typename
T
>
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
std
::
string
const
&
name
,
T
Base
::*
field
)
{
auto
getter_func
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
auto
getter_func
_helper
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
if
constexpr
(
std
::
is_member_function_pointer_v
<
decltype
(
field
)
>
)
{
return
(
self
.
get
()
->*
field
)();
}
else
{
...
...
@@ -272,6 +393,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
auto
getter_func
=
[
field
=
std
::
move
(
field
),
getter_func_helper
=
std
::
move
(
getter_func_helper
)](
SelfPtr
const
&
self
)
{
auto
val
=
getter_func_helper
(
self
);
// upconvert uint8_t, int32_t etc. to int64_t for python
if
constexpr
(
std
::
is_integral_v
<
T
>
)
{
return
static_cast
<
int64_t
>
(
val
);
}
else
{
return
val
;
}
};
cls
.
def_property
(
name
,
getter_func
);
}
...
...
@@ -340,6 +473,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
using
ScalarTypeId
=
int64_t
;
using
ScalarTypeTorchPtr
=
c10
::
intrusive_ptr
<
ScalarTypeTorch
>
;
// "rust style" names generally following:
...
...
@@ -379,4 +513,5 @@ static inline constexpr auto kHalf = kFE5M10;
static
inline
constexpr
auto
kFloat16
=
kHalf
;
static
inline
constexpr
auto
kBFloat16
=
kFE8M7
;
static
inline
constexpr
auto
kFloat16Id
=
kFloat16
.
id
();
};
// namespace vllm
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
6aa33cb2
...
...
@@ -42,8 +42,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
block_rows
)
{}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
...
...
@@ -151,20 +151,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
return
res
;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
template
<
typename
scalar_t
,
vllm
::
ScalarTypeId
w_type_id
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant
(
int
q
);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_4bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_4bit
<
half
>
(
int
q
)
{
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU4B8
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
...
...
@@ -187,7 +188,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
_4bit
<
nv_bfloat16
>
(
int
q
)
{
dequant
<
nv_bfloat16
,
vllm
::
kU4B8
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
...
...
@@ -210,19 +211,64 @@ dequant_4bit<nv_bfloat16>(int q) {
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU4
.
id
()
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
<
nv_bfloat16
,
vllm
::
kU4
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
//
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit
<
half
>
(
int
q
)
{
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU8B128
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
...
...
@@ -242,7 +288,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
_8bit
<
nv_bfloat16
>
(
int
q
)
{
dequant
<
nv_bfloat16
,
vllm
::
kU8B128
.
id
()
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
...
...
@@ -269,68 +315,9 @@ dequant_8bit<nv_bfloat16>(int q) {
return
frag_b
;
}
// Zero-point dequantizers
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_4bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_4bit_zp
<
half
>
(
int
q
)
{
const
int
LO
=
0x000f000f
;
const
int
HI
=
0x00f000f0
;
const
int
EX
=
0x64006400
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
LO
,
EX
);
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
HI
,
EX
);
const
int
SUB
=
0x64006400
;
const
int
MUL
=
0x2c002c00
;
const
int
ADD
=
0xd400d400
;
typename
ScalarType
<
half
>::
FragB
frag_b
;
frag_b
[
0
]
=
__hsub2
(
*
reinterpret_cast
<
half2
*>
(
&
lo
),
*
reinterpret_cast
<
const
half2
*>
(
&
SUB
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
half2
*>
(
&
hi
),
*
reinterpret_cast
<
const
half2
*>
(
&
MUL
),
*
reinterpret_cast
<
const
half2
*>
(
&
ADD
));
return
frag_b
;
}
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant_4bit_zp
<
nv_bfloat16
>
(
int
q
)
{
static
constexpr
uint32_t
MASK
=
0x000f000f
;
static
constexpr
uint32_t
EX
=
0x43004300
;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int
lo
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
q
>>=
4
;
int
hi
=
lop3
<
(
0xf0
&
0xcc
)
|
0xaa
>
(
q
,
MASK
,
EX
);
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
static
constexpr
uint32_t
MUL
=
0x3F803F80
;
static
constexpr
uint32_t
ADD
=
0xC300C300
;
frag_b
[
0
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
lo
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
frag_b
[
1
]
=
__hfma2
(
*
reinterpret_cast
<
nv_bfloat162
*>
(
&
hi
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
MUL
),
*
reinterpret_cast
<
const
nv_bfloat162
*>
(
&
ADD
));
return
frag_b
;
}
template
<
typename
scalar_t
>
__device__
inline
typename
ScalarType
<
scalar_t
>::
FragB
dequant_8bit_zp
(
int
q
)
{
STATIC_ASSERT_SCALAR_TYPE_VALID
(
scalar_t
);
}
template
<
>
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant_8bit_zp
<
half
>
(
int
q
)
{
__device__
inline
typename
ScalarType
<
half
>::
FragB
dequant
<
half
,
vllm
::
kU8
.
id
()
>
(
int
q
)
{
static
constexpr
uint32_t
mask_for_elt_01
=
0x5250
;
static
constexpr
uint32_t
mask_for_elt_23
=
0x5351
;
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
...
...
@@ -350,7 +337,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
template
<
>
__device__
inline
typename
ScalarType
<
nv_bfloat16
>::
FragB
dequant
_8bit_zp
<
nv_bfloat16
>
(
int
q
)
{
dequant
<
nv_bfloat16
,
vllm
::
kU8
.
id
()
>
(
int
q
)
{
typename
ScalarType
<
nv_bfloat16
>::
FragB
frag_b
;
float
fp32_intermediates
[
4
];
...
...
@@ -517,8 +504,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}
}
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
int
num_bits
,
// number of bits used for weights
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
...
...
@@ -568,7 +555,9 @@ __global__ void Marlin(
using
FragS
=
typename
ScalarType
<
scalar_t
>::
FragS
;
using
FragZP
=
typename
ScalarType
<
scalar_t
>::
FragZP
;
constexpr
int
pack_factor
=
32
/
num_bits
;
static
constexpr
auto
w_type
=
vllm
::
ScalarType
::
from_id
(
w_type_id
);
constexpr
int
pack_factor
=
32
/
w_type
.
size_bits
();
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
...
...
@@ -670,7 +659,7 @@ __global__ void Marlin(
// B sizes/strides
int
b_gl_stride
=
16
*
prob_n
/
(
pack_factor
*
4
);
constexpr
int
b_sh_stride
=
((
thread_n_blocks
*
16
)
*
16
/
pack_factor
)
/
4
;
constexpr
int
b_thread_vecs
=
num
_bits
==
4
?
1
:
2
;
constexpr
int
b_thread_vecs
=
w_type
.
size
_bits
()
==
4
?
1
:
2
;
constexpr
int
b_sh_stride_threads
=
b_sh_stride
/
b_thread_vecs
;
int
b_gl_rd_delta_o
=
b_gl_stride
*
thread_k_blocks
;
...
...
@@ -1186,19 +1175,20 @@ __global__ void Marlin(
if
constexpr
(
has_zp
)
{
FragB
frag_zp_0
;
FragB
frag_zp_1
;
if
constexpr
(
num_bits
==
4
)
{
int
zp_quant
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_shift
=
zp_quant
>>
8
;
frag_zp_0
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant
);
frag_zp_1
=
dequant_4bit_zp
<
scalar_t
>
(
zp_quant_shift
);
int
zp_quant_0
,
zp_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
zp_quant_0
>>
8
;
}
else
{
int
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
int
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
frag_zp_0
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_0
);
frag_zp_1
=
dequant_8bit_zp
<
scalar_t
>
(
zp_quant_1
);
static_assert
(
w_type
.
size_bits
()
==
8
);
zp_quant_0
=
frag_qzp
[
k
%
2
][
0
];
zp_quant_1
=
frag_qzp
[
k
%
2
][
1
];
}
frag_zp_0
=
dequant
<
scalar_t
,
w_type_id
>
(
zp_quant_0
);
frag_zp_1
=
dequant
<
scalar_t
,
w_type_id
>
(
zp_quant_1
);
frag_zp
[
0
]
=
frag_zp_0
[
0
];
frag_zp
[
1
]
=
frag_zp_0
[
1
];
frag_zp
[
2
]
=
frag_zp_1
[
0
];
...
...
@@ -1211,33 +1201,21 @@ __global__ void Marlin(
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
FragB
frag_b0
;
FragB
frag_b1
;
if
constexpr
(
num_bits
==
4
)
{
int
b_quant
=
frag_b_quant
[
k
%
2
][
0
][
j
];
int
b_quant_shift
=
b_quant
>>
8
;
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit_zp
<
scalar_t
>
(
b_quant_shift
);
}
else
{
frag_b0
=
dequant_4bit
<
scalar_t
>
(
b_quant
);
frag_b1
=
dequant_4bit
<
scalar_t
>
(
b_quant_shift
);
}
int
b_quant_0
,
b_quant_1
;
if
constexpr
(
w_type
.
size_bits
()
==
4
)
{
b_quant_0
=
frag_b_quant
[
k
%
2
][
0
][
j
];
b_quant_1
=
b_quant_0
>>
8
;
}
else
{
static_assert
(
w_type
.
size_bits
()
==
8
);
int
*
frag_b_quant_ptr
=
reinterpret_cast
<
int
*>
(
frag_b_quant
[
k
%
2
]);
int
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
int
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
if
constexpr
(
has_zp
)
{
frag_b0
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit_zp
<
scalar_t
>
(
b_quant_1
);
}
else
{
frag_b0
=
dequant_8bit
<
scalar_t
>
(
b_quant_0
);
frag_b1
=
dequant_8bit
<
scalar_t
>
(
b_quant_1
);
}
b_quant_0
=
frag_b_quant_ptr
[
j
*
2
+
0
];
b_quant_1
=
frag_b_quant_ptr
[
j
*
2
+
1
];
}
frag_b0
=
dequant
<
scalar_t
,
w_type_id
>
(
b_quant_0
);
frag_b1
=
dequant
<
scalar_t
,
w_type_id
>
(
b_quant_1
);
// Apply zero-point to frag_b0
if
constexpr
(
has_zp
)
{
sub_zp
<
scalar_t
>
(
frag_b0
,
frag_zp
[
j
],
0
);
...
...
@@ -1477,7 +1455,8 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
4
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
4
)
{
res
=
__hmul2
(
res
,
s
[
0
]);
}
...
...
@@ -1605,7 +1584,7 @@ __global__ void Marlin(
// For per-column scales, we only fetch them here in the final step before
// write-out
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num
_bits
==
8
)
{
if
constexpr
(
w_type
.
size
_bits
()
==
8
)
{
if
(
s_sh_wr_pred
)
{
cp_async4
(
&
sh_s
[
s_sh_wr
],
&
scales_ptr
[
s_gl_rd
]);
}
...
...
@@ -1622,7 +1601,7 @@ __global__ void Marlin(
thread_block_reduce
();
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
)
{
if
constexpr
(
num
_bits
==
8
)
{
if
constexpr
(
w_type
.
size
_bits
()
==
8
)
{
cp_async_wait
<
0
>
();
__syncthreads
();
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
...
...
@@ -1645,7 +1624,8 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
num_bits
==
8
)
{
if
constexpr
(
!
has_act_order
&&
group_blocks
==
-
1
&&
w_type
.
size_bits
()
==
8
)
{
if
(
threadIdx
.
x
/
32
<
thread_n_blocks
/
4
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
thread_m_blocks
;
i
++
)
{
...
...
@@ -1714,20 +1694,19 @@ __global__ void Marlin(
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t,
NUM_BITS
, NUM_THREADS, THREAD_M_BLOCKS,
\
Marlin<scalar_t,
W_TYPE.id()
, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t,
NUM_BITS
, NUM_THREADS, THREAD_M_BLOCKS,
\
Marlin<scalar_t,
W_TYPE.id()
, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
...
...
@@ -1923,52 +1902,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return
exec_config_t
{
0
,
{
-
1
,
-
1
,
-
1
}};
}
#define GPTQ_CALL_IF(
NUM_BITS
, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(
NUM_BITS
, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
NUM_BITS
, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
#define GPTQ_CALL_IF(
W_TYPE
, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(
W_TYPE
, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(
W_TYPE
, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template
<
typename
scalar_t
>
void
marlin_mm
(
const
void
*
A
,
const
void
*
B
,
void
*
C
,
void
*
C_tmp
,
void
*
s
,
...
...
@@ -2113,23 +2092,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
if
(
false
)
{
}
GPTQ_CALL_IF
(
4
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
4
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
4
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
8
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
8
,
4
,
8
,
128
)
AWQ_CALL_IF
(
4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
8
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU4B8
,
4
,
8
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU8B12
8
,
16
,
4
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU8B12
8
,
8
,
8
,
256
)
GPTQ_CALL_IF
(
vllm
::
kU8B12
8
,
8
,
4
,
128
)
GPTQ_CALL_IF
(
vllm
::
kU8B12
8
,
4
,
8
,
128
)
AWQ_CALL_IF
(
vllm
::
kU
4
,
16
,
4
,
256
)
AWQ_CALL_IF
(
vllm
::
kU
4
,
8
,
8
,
256
)
AWQ_CALL_IF
(
vllm
::
kU
4
,
8
,
4
,
128
)
AWQ_CALL_IF
(
vllm
::
kU
4
,
4
,
8
,
128
)
AWQ_CALL_IF
(
vllm
::
kU
8
,
16
,
4
,
256
)
AWQ_CALL_IF
(
vllm
::
kU
8
,
8
,
8
,
256
)
AWQ_CALL_IF
(
vllm
::
kU
8
,
8
,
4
,
128
)
AWQ_CALL_IF
(
vllm
::
kU
8
,
4
,
8
,
128
)
else
{
TORCH_CHECK
(
false
,
"Unsupported shapes: MNK = ["
,
prob_m
,
", "
,
prob_n
,
", "
,
prob_k
,
"]"
,
", has_act_order = "
,
has_act_order
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment