Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6aa33cb2
Unverified
Commit
6aa33cb2
authored
Aug 12, 2024
by
Lucas Wilkinson
Committed by
GitHub
Aug 12, 2024
Browse files
[Misc] Use scalar type to dispatch to different `gptq_marlin` kernels (#7323)
parent
1137f343
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
334 additions
and
220 deletions
+334
-220
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+166
-31
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+168
-189
No files found.
csrc/core/scalar_type.hpp
View file @
6aa33cb2
...
...
@@ -20,7 +20,7 @@ namespace vllm {
//
class
ScalarType
{
public:
enum
NanRepr
:
int
64
_t
{
enum
NanRepr
:
u
int
8
_t
{
NAN_NONE
=
0
,
// nans are not supported
NAN_IEEE_754
=
1
,
// nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN
=
2
,
// nans are: exp all 1s, mantissa all 1s
...
...
@@ -28,33 +28,33 @@ class ScalarType {
NAN_REPR_ID_MAX
};
constexpr
ScalarType
(
bool
signed_
,
int
64
_t
exponent
,
int
64
_t
mantissa
,
int
64
_t
bias
,
bool
finite_values_only
=
false
,
constexpr
ScalarType
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
signed_
,
int
32
_t
bias
,
bool
finite_values_only
=
false
,
NanRepr
nan_repr
=
NAN_IEEE_754
)
:
exponent
(
exponent
),
mantissa
(
mantissa
),
bias
(
bias
),
signed_
(
signed_
),
bias
(
bias
),
finite_values_only
(
finite_values_only
),
nan_repr
(
nan_repr
){};
static
constexpr
ScalarType
int_
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
true
,
0
,
size_bits
-
1
,
bias
);
static
constexpr
ScalarType
int_
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
-
1
,
true
,
bias
);
}
static
constexpr
ScalarType
uint
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
false
,
0
,
size_bits
,
bias
);
static
constexpr
ScalarType
uint
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
,
false
,
bias
);
}
// IEEE 754 compliant floating point type
static
constexpr
ScalarType
float_IEEE754
(
int
64
_t
exponent
,
int
64
_t
mantissa
)
{
static
constexpr
ScalarType
float_IEEE754
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
)
{
TORCH_CHECK
(
mantissa
>
0
&&
exponent
>
0
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
false
,
NAN_IEEE_754
);
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
false
,
NAN_IEEE_754
);
}
// IEEE 754 non-compliant floating point type
static
constexpr
ScalarType
float_
(
int
64
_t
exponent
,
int
64
_t
mantissa
,
static
constexpr
ScalarType
float_
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
finite_values_only
,
NanRepr
nan_repr
)
{
TORCH_CHECK
(
nan_repr
<
NAN_REPR_ID_MAX
,
"Invalid NanRepr"
);
...
...
@@ -62,36 +62,121 @@ class ScalarType {
TORCH_CHECK
(
nan_repr
!=
NAN_IEEE_754
,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
finite_values_only
,
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
finite_values_only
,
nan_repr
);
}
int
64
_t
const
exponent
;
// size of the exponent field (0 for integer types)
int
64
_t
const
mantissa
;
// size of the mantissa field (size of the integer
u
int
8
_t
const
exponent
;
// size of the exponent field (0 for integer types)
u
int
8
_t
const
mantissa
;
// size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
bool
const
signed_
;
// flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool
const
finite_values_only
;
// i.e. no +/-inf if true
NanRepr
const
nan_repr
;
// how NaNs are represented
// (not applicable for integer types)
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
bool
is_signed
()
const
{
return
signed_
;
}
bool
is_integer
()
const
{
return
exponent
==
0
;
}
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
bool
is_ieee_754
()
const
{
using
Id
=
int64_t
;
private:
// Field size in id
template
<
typename
T_
>
static
constexpr
size_t
member_id_field_width
()
{
using
T
=
std
::
decay_t
<
T_
>
;
return
std
::
is_same_v
<
T
,
bool
>
?
1
:
sizeof
(
T
)
*
8
;
}
template
<
typename
Fn
,
typename
Init
,
typename
Member
,
typename
...
Rest
>
static
constexpr
auto
reduce_members_helper
(
Fn
f
,
Init
val
,
Member
member
,
Rest
...
rest
)
{
auto
new_val
=
f
(
val
,
member
);
if
constexpr
(
sizeof
...(
rest
)
>
0
)
{
return
reduce_members_helper
(
f
,
new_val
,
rest
...);
}
else
{
return
new_val
;
};
}
template
<
typename
Fn
,
typename
Init
>
constexpr
auto
reduce_members
(
Fn
f
,
Init
init
)
const
{
// Should be in constructor order for `from_id`
return
reduce_members_helper
(
f
,
init
,
exponent
,
mantissa
,
signed_
,
bias
,
finite_values_only
,
nan_repr
);
};
template
<
typename
Fn
,
typename
Init
>
static
constexpr
auto
reduce_member_types
(
Fn
f
,
Init
init
)
{
constexpr
auto
dummy_type
=
ScalarType
(
0
,
0
,
false
,
0
,
false
,
NAN_NONE
);
return
dummy_type
.
reduce_members
(
f
,
init
);
};
static
constexpr
auto
id_size_bits
()
{
return
reduce_member_types
(
[](
int
acc
,
auto
member
)
->
int
{
return
acc
+
member_id_field_width
<
decltype
(
member
)
>
();
},
0
);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr
Id
id
()
const
{
static_assert
(
id_size_bits
()
<=
sizeof
(
Id
)
*
8
,
"ScalarType id is too large to be stored"
);
auto
or_and_advance
=
[](
std
::
pair
<
Id
,
uint32_t
>
result
,
auto
member
)
->
std
::
pair
<
Id
,
uint32_t
>
{
auto
[
id
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
decltype
(
member
)
>
();
return
{
id
|
(
int64_t
(
member
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
))
<<
bit_offset
,
bit_offset
+
bits
};
};
return
reduce_members
(
or_and_advance
,
std
::
pair
<
Id
,
uint32_t
>
{}).
first
;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static
constexpr
ScalarType
from_id
(
Id
id
)
{
auto
extract_and_advance
=
[
id
](
auto
result
,
auto
member
)
{
using
T
=
decltype
(
member
);
auto
[
tuple
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
T
>
();
auto
extracted_val
=
static_cast
<
T
>
((
int64_t
(
id
)
>>
bit_offset
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
));
auto
new_tuple
=
std
::
tuple_cat
(
tuple
,
std
::
make_tuple
(
extracted_val
));
return
std
::
pair
<
decltype
(
new_tuple
),
int
>
{
new_tuple
,
bit_offset
+
bits
};
};
auto
[
tuple_args
,
_
]
=
reduce_member_types
(
extract_and_advance
,
std
::
pair
<
std
::
tuple
<>
,
int
>
{});
return
std
::
apply
([](
auto
...
args
)
{
return
ScalarType
(
args
...);
},
tuple_args
);
}
constexpr
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
constexpr
bool
is_signed
()
const
{
return
signed_
;
}
constexpr
bool
is_integer
()
const
{
return
exponent
==
0
;
}
constexpr
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
constexpr
bool
is_ieee_754
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
&&
nan_repr
==
NAN_IEEE_754
;
}
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
bool
has_infs
()
const
{
constexpr
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
constexpr
bool
has_infs
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
;
}
bool
has_bias
()
const
{
return
bias
!=
0
;
}
constexpr
bool
has_bias
()
const
{
return
bias
!=
0
;
}
private:
double
_floating_point_max
()
const
{
...
...
@@ -131,7 +216,7 @@ class ScalarType {
return
*
reinterpret_cast
<
double
*>
(
&
double_raw
);
}
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
if
(
is_floating_point
())
{
return
{
_floating_point_max
()};
}
else
{
...
...
@@ -141,7 +226,7 @@ class ScalarType {
}
}
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
if
(
is_floating_point
())
{
TORCH_CHECK
(
is_signed
(),
"We currently assume all floating point types are signed"
);
...
...
@@ -168,7 +253,7 @@ class ScalarType {
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
max
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_max
());
...
...
@@ -176,7 +261,7 @@ class ScalarType {
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
min
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_min
());
...
...
@@ -215,7 +300,7 @@ class ScalarType {
}
}
bool
operator
==
(
ScalarType
const
&
other
)
const
{
constexpr
bool
operator
==
(
ScalarType
const
&
other
)
const
{
return
mantissa
==
other
.
mantissa
&&
exponent
==
other
.
exponent
&&
bias
==
other
.
bias
&&
signed_
==
other
.
signed_
&&
finite_values_only
==
other
.
finite_values_only
&&
...
...
@@ -240,23 +325,59 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
using
Self
=
ScalarTypeTorch
;
using
SelfPtr
=
c10
::
intrusive_ptr
<
Self
>
;
static
void
check_size_bits
(
int64_t
size_bits
,
bool
signed_
)
{
TORCH_CHECK
(
size_bits
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"size_bits bit width is too large to be represented"
);
}
static
void
check_bias
(
int64_t
bias
)
{
using
Bias
=
decltype
(
std
::
declval
<
Self
>
().
bias
);
TORCH_CHECK
(
bias
<=
std
::
numeric_limits
<
Bias
>::
max
()
&&
bias
>=
std
::
numeric_limits
<
Bias
>::
min
(),
"bias too large or small to be represented"
);
}
static
void
check_exponent
(
int64_t
exponent
)
{
TORCH_CHECK
(
exponent
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
exponent
)
>::
max
(),
"exponent bit width is too large to be represented"
);
}
static
void
check_mantissa
(
int64_t
mantissa
)
{
TORCH_CHECK
(
mantissa
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"mantissa bit width is too large to be represented"
);
}
static
SelfPtr
int_
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
int_
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
uint
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
uint
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
float_IEEE754
(
int64_t
exponent
,
int64_t
mantissa
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_IEEE754
(
exponent
,
mantissa
));
}
static
SelfPtr
float_
(
int64_t
exponent
,
int64_t
mantissa
,
bool
finite_values_only
,
int64_t
nan_repr
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_
(
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
}
...
...
@@ -264,7 +385,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
template
<
typename
T
>
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
std
::
string
const
&
name
,
T
Base
::*
field
)
{
auto
getter_func
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
auto
getter_func
_helper
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
if
constexpr
(
std
::
is_member_function_pointer_v
<
decltype
(
field
)
>
)
{
return
(
self
.
get
()
->*
field
)();
}
else
{
...
...
@@ -272,6 +393,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
auto
getter_func
=
[
field
=
std
::
move
(
field
),
getter_func_helper
=
std
::
move
(
getter_func_helper
)](
SelfPtr
const
&
self
)
{
auto
val
=
getter_func_helper
(
self
);
// upconvert uint8_t, int32_t etc. to int64_t for python
if
constexpr
(
std
::
is_integral_v
<
T
>
)
{
return
static_cast
<
int64_t
>
(
val
);
}
else
{
return
val
;
}
};
cls
.
def_property
(
name
,
getter_func
);
}
...
...
@@ -340,6 +473,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
using
ScalarTypeId
=
int64_t
;
using
ScalarTypeTorchPtr
=
c10
::
intrusive_ptr
<
ScalarTypeTorch
>
;
// "rust style" names generally following:
...
...
@@ -379,4 +513,5 @@ static inline constexpr auto kHalf = kFE5M10;
static
inline
constexpr
auto
kFloat16
=
kHalf
;
static
inline
constexpr
auto
kBFloat16
=
kFE8M7
;
static
inline
constexpr
auto
kFloat16Id
=
kFloat16
.
id
();
};
// namespace vllm
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
6aa33cb2
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment