Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7759ae95
Unverified
Commit
7759ae95
authored
Aug 16, 2024
by
bnellnm
Committed by
GitHub
Aug 16, 2024
Browse files
[Kernel][Misc] dynamo support for ScalarType (#7594)
parent
9f698563
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
149 additions
and
24 deletions
+149
-24
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+30
-0
vllm/_core_ext.py
vllm/_core_ext.py
+119
-24
No files found.
csrc/core/scalar_type.hpp
View file @
7759ae95
...
...
@@ -313,6 +313,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class
ScalarTypeTorch
:
public
torch
::
CustomClassHolder
,
public
ScalarType
{
public:
ScalarTypeTorch
(
int64_t
exponent
,
int64_t
mantissa
,
int64_t
bias
,
...
...
@@ -382,6 +384,29 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
}
// Implementation of the `__len__` method bound on the torch class.
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
// The function always throws c10::TypeError; the `return 0;` that follows the
// throw is never reached.
int64_t len() const {
  throw c10::TypeError("__len__ not implemented");
  return 0;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std
::
tuple
<
std
::
tuple
<
std
::
string
,
int64_t
>>
obj_flatten
()
const
{
return
{{
"ScalarType"
,
id
()}};
}
// Rebuild a scalar type from the value produced by obj_flatten. The argument
// is nominally a tuple of (member name, value) pairs, but only the value of
// the first pair is read; it is passed to from_id() as a ScalarTypeId and the
// result is wrapped in a new intrusive pointer to Self.
static SelfPtr obj_unflatten(
    std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
  auto const& id_field = std::get<0>(flat_type);
  auto const& type_id = std::get<1>(id_field);
  return c10::make_intrusive<Self>(from_id(type_id));
}
template
<
typename
T
>
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
std
::
string
const
&
name
,
T
Base
::*
field
)
{
...
...
@@ -457,6 +482,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self
.
get
()
->
min
());
});
bind_function
(
cls
,
"__len__"
,
&
ScalarTypeTorch
::
len
);
bind_function
(
cls
,
"__str__"
,
&
Base
::
str
);
bind_function
(
cls
,
"__eq__"
,
[](
SelfPtr
const
&
self
,
SelfPtr
const
&
other
)
{
return
*
self
==
*
other
;
...
...
@@ -465,6 +491,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return
"ScalarType."
+
self
.
get
()
->
str
();
});
bind_function
(
cls
,
"__obj_flatten__"
,
&
ScalarTypeTorch
::
obj_flatten
);
bind_static_function
(
cls
,
"__obj_unflatten__"
,
&
ScalarTypeTorch
::
obj_unflatten
);
// Bind static functions (convenience constructors)
bind_static_function
(
cls
,
"int_"
,
&
ScalarTypeTorch
::
int_
);
bind_static_function
(
cls
,
"uint"
,
&
ScalarTypeTorch
::
uint
);
...
...
vllm/_core_ext.py
View file @
7759ae95
import
importlib.util
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -103,23 +103,23 @@ if TYPE_CHECKING or not core_C_available:
"""
...
def
is_floating_point
(
self
):
def is_floating_point(self) -> bool:
    """Return True when `self.exponent` is non-zero (a floating point type)."""
    exponent = self.exponent
    return exponent != 0
def
is_integer
(
self
):
def is_integer(self) -> bool:
    """Return True when `self.exponent` is zero (an integer type)."""
    exponent = self.exponent
    return exponent == 0
def
has_bias
(
self
):
def has_bias(self) -> bool:
    """Return True when `self.bias` is non-zero."""
    bias = self.bias
    return bias != 0
def
has_infs
(
self
):
def has_infs(self) -> bool:
    """Return True unless the type is flagged as finite-values-only.

    The result is the negation of `self._finite_values_only`.
    """
    finite_only = self._finite_values_only
    return not finite_only
def
has_nans
(
self
):
def has_nans(self) -> bool:
    """Return True when `self.nan_repr` differs from `NanRepr.NONE.value`."""
    current_repr = self.nan_repr
    return current_repr != NanRepr.NONE.value
def
is_ieee_754
(
self
)
->
bool
:
...
...
@@ -136,6 +136,11 @@ if TYPE_CHECKING or not core_C_available:
def
__repr__
(
self
)
->
str
:
raise
NotImplementedError
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
# opcheck to work.
def
__len__
(
self
)
->
int
:
raise
TypeError
#
# Convenience Constructors
#
...
...
@@ -160,7 +165,7 @@ if TYPE_CHECKING or not core_C_available:
@
classmethod
def
float_
(
cls
,
exponent
:
int
,
mantissa
:
int
,
finite_values_only
:
bool
,
nan_repr
:
int
):
nan_repr
:
int
)
->
'ScalarType'
:
"""
Create a non-standard floating point type
(i.e. does not follow IEEE 754 conventions).
...
...
@@ -175,3 +180,93 @@ elif core_C_available:
logger
.
warning
(
"Failed to import from vllm._core_C with %r"
,
e
)
ScalarType
=
torch
.
classes
.
_core_C
.
ScalarType
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
    """Fake-class counterpart registered for `_core_C::ScalarType`.

    Wraps a scalar type object (stored on `self.ScalarType`) and forwards
    every attribute read and method call to it.
    """

    def __init__(self, scalar_type):
        # The wrapped object; every instance method below delegates to it.
        self.ScalarType = scalar_type

    # Attribute accessors: each returns the same-named attribute of the
    # wrapped object.
    # NOTE(review): the `*_getter` names presumably mirror the getter methods
    # torch generates for the bound read-only properties - confirm.
    def bias_getter(self) -> int:
        return self.ScalarType.bias

    def exponent_getter(self) -> int:
        return self.ScalarType.exponent

    def mantissa_getter(self) -> int:
        return self.ScalarType.mantissa

    def signed_getter(self) -> bool:
        return self.ScalarType.signed

    def size_bits_getter(self) -> int:
        return self.ScalarType.size_bits

    # `size_bits` is exposed both through the getter above and as a property.
    @property
    def size_bits(self) -> int:
        return self.ScalarType.size_bits

    # Value-range and classification queries, forwarded unchanged.
    def min(self) -> Union[int, float]:
        return self.ScalarType.min()

    def max(self) -> Union[int, float]:
        return self.ScalarType.max()

    def is_signed(self) -> bool:
        return self.ScalarType.is_signed()

    def is_floating_point(self) -> bool:
        return self.ScalarType.is_floating_point()

    def is_integer(self) -> bool:
        return self.ScalarType.is_integer()

    def has_bias(self) -> bool:
        return self.ScalarType.has_bias()

    def has_infs(self) -> bool:
        return self.ScalarType.has_infs()

    def has_nans(self) -> bool:
        return self.ScalarType.has_nans()

    def is_ieee_754(self) -> bool:
        return self.ScalarType.is_ieee_754()

    # Dunder methods, forwarded to the wrapped object's own implementations.
    def __str__(self) -> str:
        return self.ScalarType.__str__()

    def __repr__(self) -> str:
        return self.ScalarType.__repr__()

    def __len__(self) -> int:
        return self.ScalarType.__len__()

    # Flatten by calling the real class's `__obj_flatten__` on the wrapped
    # object.
    def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
        return torch.classes._core_C.ScalarType.__obj_flatten__(
            self.ScalarType)

    # Unflatten with the real class's `__obj_unflatten__`, then wrap the
    # result in `cls`.
    # NOTE(review): the return annotation says 'ScalarType', but the value
    # returned is an instance of `cls` - confirm which is intended.
    @classmethod
    def __obj_unflatten__(
            cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
        return cls(
            torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type))

    # Convenience constructors: each forwards its arguments to the same-named
    # constructor on the module-level `ScalarType` and returns that result
    # directly (it is not wrapped in `cls`).
    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        return ScalarType.int_(size_bits, bias)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        return ScalarType.uint(size_bits, bias)

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        return ScalarType.float_IEEE754(exponent, mantissa)

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        return ScalarType.float_(exponent, mantissa, finite_values_only,
                                 nan_repr)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment