Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7759ae95
Unverified
Commit
7759ae95
authored
Aug 16, 2024
by
bnellnm
Committed by
GitHub
Aug 16, 2024
Browse files
[Kernel][Misc] dynamo support for ScalarType (#7594)
parent
9f698563
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
149 additions
and
24 deletions
+149
-24
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+30
-0
vllm/_core_ext.py
vllm/_core_ext.py
+119
-24
No files found.
csrc/core/scalar_type.hpp
View file @
7759ae95
...
@@ -313,6 +313,8 @@ class ScalarType {
...
@@ -313,6 +313,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class
ScalarTypeTorch
:
public
torch
::
CustomClassHolder
,
public
ScalarType
{
class
ScalarTypeTorch
:
public
torch
::
CustomClassHolder
,
public
ScalarType
{
public:
public:
ScalarTypeTorch
(
int64_t
exponent
,
int64_t
mantissa
,
int64_t
bias
,
ScalarTypeTorch
(
int64_t
exponent
,
int64_t
mantissa
,
int64_t
bias
,
...
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
...
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
}
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t
len
()
const
{
throw
c10
::
TypeError
(
"__len__ not implemented"
);
return
0
;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std
::
tuple
<
std
::
tuple
<
std
::
string
,
int64_t
>>
obj_flatten
()
const
{
return
{{
"ScalarType"
,
id
()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static
SelfPtr
obj_unflatten
(
std
::
tuple
<
std
::
tuple
<
std
::
string
,
int64_t
>>
const
&
flat_type
)
{
return
c10
::
make_intrusive
<
Self
>
(
from_id
(
std
::
get
<
1
>
(
std
::
get
<
0
>
(
flat_type
))));
}
template
<
typename
T
>
template
<
typename
T
>
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
std
::
string
const
&
name
,
T
Base
::*
field
)
{
std
::
string
const
&
name
,
T
Base
::*
field
)
{
...
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
...
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self
.
get
()
->
min
());
self
.
get
()
->
min
());
});
});
bind_function
(
cls
,
"__len__"
,
&
ScalarTypeTorch
::
len
);
bind_function
(
cls
,
"__str__"
,
&
Base
::
str
);
bind_function
(
cls
,
"__str__"
,
&
Base
::
str
);
bind_function
(
cls
,
"__eq__"
,
[](
SelfPtr
const
&
self
,
SelfPtr
const
&
other
)
{
bind_function
(
cls
,
"__eq__"
,
[](
SelfPtr
const
&
self
,
SelfPtr
const
&
other
)
{
return
*
self
==
*
other
;
return
*
self
==
*
other
;
...
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
...
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return
"ScalarType."
+
self
.
get
()
->
str
();
return
"ScalarType."
+
self
.
get
()
->
str
();
});
});
bind_function
(
cls
,
"__obj_flatten__"
,
&
ScalarTypeTorch
::
obj_flatten
);
bind_static_function
(
cls
,
"__obj_unflatten__"
,
&
ScalarTypeTorch
::
obj_unflatten
);
// Bind static functions (convenience constructors)
// Bind static functions (convenience constructors)
bind_static_function
(
cls
,
"int_"
,
&
ScalarTypeTorch
::
int_
);
bind_static_function
(
cls
,
"int_"
,
&
ScalarTypeTorch
::
int_
);
bind_static_function
(
cls
,
"uint"
,
&
ScalarTypeTorch
::
uint
);
bind_static_function
(
cls
,
"uint"
,
&
ScalarTypeTorch
::
uint
);
...
...
vllm/_core_ext.py
View file @
7759ae95
import
importlib.util
import
importlib.util
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -31,14 +31,14 @@ if TYPE_CHECKING or not core_C_available:
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
ScalarType
:
class
ScalarType
:
"""
"""
ScalarType can represent a wide range of floating point and integer
ScalarType can represent a wide range of floating point and integer
types, in particular it can be used to represent sub-byte data types
types, in particular it can be used to represent sub-byte data types
(something that torch.dtype currently does not support). It is also
(something that torch.dtype currently does not support). It is also
capable of representing types with a bias, i.e.:
capable of representing types with a bias, i.e.:
`stored_value = value + bias`,
`stored_value = value + bias`,
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
of 8). The implementation for this class can be found in
of 8). The implementation for this class can be found in
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
with that file.
with that file.
"""
"""
...
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -51,15 +51,15 @@ if TYPE_CHECKING or not core_C_available:
mantissa
:
int
mantissa
:
int
"""
"""
Number of bits in the mantissa if this is a floating point type,
Number of bits in the mantissa if this is a floating point type,
or the number bits representing an integer excluding the sign bit if
or the number bits representing an integer excluding the sign bit if
this an integer type.
this an integer type.
"""
"""
bias
:
int
bias
:
int
"""
"""
bias used to encode the values in this scalar type
bias used to encode the values in this scalar type
(value = stored_value - bias, default 0) for example if we store the
(value = stored_value - bias, default 0) for example if we store the
type as an unsigned integer with a bias of 128 then the value 0 will be
type as an unsigned integer with a bias of 128 then the value 0 will be
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
"""
"""
...
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -73,7 +73,7 @@ if TYPE_CHECKING or not core_C_available:
nan_repr
:
int
=
NanRepr
.
IEEE_754
.
value
nan_repr
:
int
=
NanRepr
.
IEEE_754
.
value
"""
"""
How NaNs are represent in this scalar type, returns NanRepr value.
How NaNs are represent in this scalar type, returns NanRepr value.
(not applicable for integer types)
(not applicable for integer types)
"""
"""
...
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -83,14 +83,14 @@ if TYPE_CHECKING or not core_C_available:
def
min
(
self
)
->
Union
[
int
,
float
]:
def
min
(
self
)
->
Union
[
int
,
float
]:
"""
"""
Min representable value for this scalar type.
Min representable value for this scalar type.
(accounting for bias if there is one)
(accounting for bias if there is one)
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def
max
(
self
)
->
Union
[
int
,
float
]:
def
max
(
self
)
->
Union
[
int
,
float
]:
"""
"""
Max representable value for this scalar type.
Max representable value for this scalar type.
(accounting for bias if there is one)
(accounting for bias if there is one)
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -103,28 +103,28 @@ if TYPE_CHECKING or not core_C_available:
"""
"""
...
...
def
is_floating_point
(
self
):
def
is_floating_point
(
self
)
->
bool
:
"If the type is a floating point type"
"If the type is a floating point type"
return
self
.
exponent
!=
0
return
self
.
exponent
!=
0
def
is_integer
(
self
):
def
is_integer
(
self
)
->
bool
:
"If the type is an integer type"
"If the type is an integer type"
return
self
.
exponent
==
0
return
self
.
exponent
==
0
def
has_bias
(
self
):
def
has_bias
(
self
)
->
bool
:
"If the type has a non-zero bias"
"If the type has a non-zero bias"
return
self
.
bias
!=
0
return
self
.
bias
!=
0
def
has_infs
(
self
):
def
has_infs
(
self
)
->
bool
:
"If the type is floating point and supports infinity"
"If the type is floating point and supports infinity"
return
not
self
.
_finite_values_only
return
not
self
.
_finite_values_only
def
has_nans
(
self
):
def
has_nans
(
self
)
->
bool
:
return
self
.
nan_repr
!=
NanRepr
.
NONE
.
value
return
self
.
nan_repr
!=
NanRepr
.
NONE
.
value
def
is_ieee_754
(
self
)
->
bool
:
def
is_ieee_754
(
self
)
->
bool
:
"""
"""
If the type is a floating point type that follows IEEE 754
If the type is a floating point type that follows IEEE 754
conventions
conventions
"""
"""
return
self
.
nan_repr
==
NanRepr
.
IEEE_754
.
value
and
\
return
self
.
nan_repr
==
NanRepr
.
IEEE_754
.
value
and
\
...
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
raise
NotImplementedError
raise
NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def
__len__
(
self
)
->
int
:
raise
TypeError
#
#
# Convenience Constructors
# Convenience Constructors
#
#
...
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
...
@@ -153,16 +158,16 @@ if TYPE_CHECKING or not core_C_available:
@
classmethod
@
classmethod
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
"""
"""
Create a standard floating point type
Create a standard floating point type
(i.e. follows IEEE 754 conventions).
(i.e. follows IEEE 754 conventions).
"""
"""
return
cls
(
exponent
,
mantissa
,
0
,
True
)
return
cls
(
exponent
,
mantissa
,
0
,
True
)
@
classmethod
@
classmethod
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
nan_repr
:
int
):
nan_repr
:
int
)
->
'ScalarType'
:
"""
"""
Create a non-standard floating point type
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
(i.e. does not follow IEEE 754 conventions).
"""
"""
return
cls
(
exponent
,
mantissa
,
0
,
True
,
finite_values_only
,
return
cls
(
exponent
,
mantissa
,
0
,
True
,
finite_values_only
,
...
@@ -175,3 +180,93 @@ elif core_C_available:
...
@@ -175,3 +180,93 @@ elif core_C_available:
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
# Needed for dynamo support of ScalarType.
@
torch
.
_library
.
register_fake_class
(
"_core_C::ScalarType"
)
class
FakeScalarType
:
def
__init__
(
self
,
scalar_type
):
self
.
ScalarType
=
scalar_type
def
bias_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
bias
def
exponent_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
exponent
def
mantissa_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
mantissa
def
signed_getter
(
self
)
->
bool
:
return
self
.
ScalarType
.
signed
def
size_bits_getter
(
self
)
->
int
:
return
self
.
ScalarType
.
size_bits
@
property
def
size_bits
(
self
)
->
int
:
return
self
.
ScalarType
.
size_bits
def
min
(
self
)
->
Union
[
int
,
float
]:
return
self
.
ScalarType
.
min
()
def
max
(
self
)
->
Union
[
int
,
float
]:
return
self
.
ScalarType
.
max
()
def
is_signed
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_signed
()
def
is_floating_point
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_floating_point
()
def
is_integer
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_integer
()
def
has_bias
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_bias
()
def
has_infs
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_infs
()
def
has_nans
(
self
)
->
bool
:
return
self
.
ScalarType
.
has_nans
()
def
is_ieee_754
(
self
)
->
bool
:
return
self
.
ScalarType
.
is_ieee_754
()
def
__str__
(
self
)
->
str
:
return
self
.
ScalarType
.
__str__
()
def
__repr__
(
self
)
->
str
:
return
self
.
ScalarType
.
__repr__
()
def
__len__
(
self
)
->
int
:
return
self
.
ScalarType
.
__len__
()
def
__obj_flatten__
(
self
)
->
Tuple
[
Tuple
[
str
,
Any
],
...]:
return
torch
.
classes
.
_core_C
.
ScalarType
.
__obj_flatten__
(
self
.
ScalarType
)
@
classmethod
def
__obj_unflatten__
(
cls
,
flat_type
:
Tuple
[
Tuple
[
str
,
Any
],
...])
->
'ScalarType'
:
return
cls
(
torch
.
classes
.
_core_C
.
ScalarType
.
__obj_unflatten__
(
flat_type
))
@
classmethod
def
int_
(
cls
,
size_bits
:
int
,
bias
:
Optional
[
int
])
->
'ScalarType'
:
return
ScalarType
.
int_
(
size_bits
,
bias
)
@
classmethod
def
uint
(
cls
,
size_bits
:
int
,
bias
:
Optional
[
int
])
->
'ScalarType'
:
return
ScalarType
.
uint
(
size_bits
,
bias
)
@
classmethod
def
float_IEEE754
(
cls
,
exponent
:
int
,
mantissa
:
int
)
->
'ScalarType'
:
return
ScalarType
.
float_IEEE754
(
exponent
,
mantissa
)
@
classmethod
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
nan_repr
:
int
)
->
'ScalarType'
:
return
ScalarType
.
float_
(
exponent
,
mantissa
,
finite_values_only
,
nan_repr
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment