Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9a2e74d
Unverified
Commit
a9a2e74d
authored
Jul 17, 2024
by
Woosuk Kwon
Committed by
GitHub
Jul 17, 2024
Browse files
[Misc] Use `torch.Tensor` for type annotation (#6505)
parent
e09ce759
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
18 deletions
+18
-18
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+17
-17
vllm/worker/worker.py
vllm/worker/worker.py
+1
-1
No files found.
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
a9a2e74d
...
@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1]
...
@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1]
# helpers
# helpers
def
to_fp8
(
tensor
:
torch
.
t
ensor
)
->
torch
.
t
ensor
:
def
to_fp8
(
tensor
:
torch
.
T
ensor
)
->
torch
.
T
ensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
t
ensor
)
->
torch
.
t
ensor
:
def
to_int8
(
tensor
:
torch
.
T
ensor
)
->
torch
.
T
ensor
:
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
make_rand_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
def
make_rand_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
Tuple
[
torch
.
t
ensor
,
torch
.
t
ensor
]:
k
:
int
)
->
Tuple
[
torch
.
T
ensor
,
torch
.
T
ensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
*
5
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
*
5
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
*
5
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
*
5
...
@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
...
@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
# impl
# impl
def
pytorch_mm_impl
(
a
:
torch
.
t
ensor
,
b
:
torch
.
t
ensor
,
scale_a
:
torch
.
t
ensor
,
def
pytorch_mm_impl
(
a
:
torch
.
T
ensor
,
b
:
torch
.
T
ensor
,
scale_a
:
torch
.
T
ensor
,
scale_b
:
torch
.
t
ensor
,
scale_b
:
torch
.
T
ensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
t
ensor
:
out_dtype
:
torch
.
dtype
)
->
torch
.
T
ensor
:
return
torch
.
mm
(
a
,
b
)
return
torch
.
mm
(
a
,
b
)
def
pytorch_fp8_impl
(
a
:
torch
.
t
ensor
,
b
:
torch
.
t
ensor
,
scale_a
:
torch
.
t
ensor
,
def
pytorch_fp8_impl
(
a
:
torch
.
T
ensor
,
b
:
torch
.
T
ensor
,
scale_a
:
torch
.
T
ensor
,
scale_b
:
torch
.
t
ensor
,
scale_b
:
torch
.
T
ensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
t
ensor
:
out_dtype
:
torch
.
dtype
)
->
torch
.
T
ensor
:
return
torch
.
_scaled_mm
(
a
,
return
torch
.
_scaled_mm
(
a
,
b
,
b
,
scale_a
=
scale_a
,
scale_a
=
scale_a
,
...
@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
...
@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
out_dtype
=
out_dtype
)
out_dtype
=
out_dtype
)
def
pytorch_fp8_impl_fast_accum
(
a
:
torch
.
t
ensor
,
b
:
torch
.
t
ensor
,
def
pytorch_fp8_impl_fast_accum
(
a
:
torch
.
T
ensor
,
b
:
torch
.
T
ensor
,
scale_a
:
torch
.
t
ensor
,
scale_b
:
torch
.
t
ensor
,
scale_a
:
torch
.
T
ensor
,
scale_b
:
torch
.
T
ensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
t
ensor
:
out_dtype
:
torch
.
dtype
)
->
torch
.
T
ensor
:
return
torch
.
_scaled_mm
(
a
,
return
torch
.
_scaled_mm
(
a
,
b
,
b
,
scale_a
=
scale_a
,
scale_a
=
scale_a
,
...
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
...
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
use_fast_accum
=
True
)
use_fast_accum
=
True
)
def
cutlass_impl
(
a
:
torch
.
t
ensor
,
b
:
torch
.
t
ensor
,
scale_a
:
torch
.
t
ensor
,
def
cutlass_impl
(
a
:
torch
.
T
ensor
,
b
:
torch
.
T
ensor
,
scale_a
:
torch
.
T
ensor
,
scale_b
:
torch
.
t
ensor
,
scale_b
:
torch
.
T
ensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
t
ensor
:
out_dtype
:
torch
.
dtype
)
->
torch
.
T
ensor
:
return
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
)
return
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
)
# bench
# bench
def
bench_fn
(
a
:
torch
.
t
ensor
,
b
:
torch
.
t
ensor
,
scale_a
:
torch
.
t
ensor
,
def
bench_fn
(
a
:
torch
.
T
ensor
,
b
:
torch
.
T
ensor
,
scale_a
:
torch
.
T
ensor
,
scale_b
:
torch
.
t
ensor
,
out_dtype
:
torch
.
dtype
,
label
:
str
,
scale_b
:
torch
.
T
ensor
,
out_dtype
:
torch
.
dtype
,
label
:
str
,
sub_label
:
str
,
fn
:
Callable
,
description
:
str
)
->
TMeasurement
:
sub_label
:
str
,
fn
:
Callable
,
description
:
str
)
->
TMeasurement
:
min_run_time
=
1
min_run_time
=
1
...
...
vllm/worker/worker.py
View file @
a9a2e74d
...
@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
# initialize_cache.
# initialize_cache.
self
.
cache_engine
:
List
[
CacheEngine
]
self
.
cache_engine
:
List
[
CacheEngine
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
t
ensor
]]]
=
None
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
T
ensor
]]]
=
None
def
init_device
(
self
)
->
None
:
def
init_device
(
self
)
->
None
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment