norm / vllm · commit ee92b58b (unverified)

Move bfloat16 check to worker (#1259)

Authored Oct 07, 2023 by Antoni Baum; committed by GitHub on Oct 07, 2023
Parent: 09ff7f10
Showing 2 changed files with 14 additions and 9 deletions:

    vllm/config.py         +0  -9
    vllm/worker/worker.py  +14 -0
vllm/config.py

@@ -345,15 +345,6 @@ def _get_and_verify_dtype(
         # Casting between float16 and bfloat16 is allowed with a warning.
         logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}.")
 
     return torch_dtype
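The deleted check ran inside _get_and_verify_dtype, which executes while the engine configuration is being built, before workers are bound to their GPUs; moving it into the worker lets each process probe the device it will actually run on. A minimal sketch of the same capability probe, assuming a CUDA build of PyTorch (the helper name is hypothetical, not part of this commit):

    import torch

    def gpu_supports_bfloat16() -> bool:
        # bfloat16 kernels require compute capability >= 8.0 (Ampere or newer).
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability()
        return major >= 8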
vllm/worker/worker.py

@@ -59,6 +59,8 @@ class Worker:
             raise ValueError("Invalid or unspecified rank.")
         torch.cuda.set_device(self.device)
 
+        _check_if_gpu_supports_dtype(self.model_config.dtype)
+
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
                                       self.distributed_init_method)
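Note the placement: the call lands after torch.cuda.set_device(self.device) and before _init_distributed_environment, so torch.cuda.get_device_capability() queries the GPU assigned to this worker, and an unsupported node fails fast before joining the process group. A self-contained sketch of that ordering (everything except the torch calls is hypothetical):

    import torch

    def init_worker(rank: int, dtype: torch.dtype) -> None:
        # Bind this worker's GPU first so the capability probe below
        # sees the right device, then validate before distributed setup.
        torch.cuda.set_device(rank)
        if dtype == torch.bfloat16:
            major, minor = torch.cuda.get_device_capability()
            if major < 8:
                raise ValueError(f"bfloat16 requires compute capability "
                                 f">= 8.0, got {major}.{minor}")
        # ... distributed initialization would follow here ...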
@@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"(required shared memory {required_shared_mem} > "
             f"available shared memory {max_shared_mem}). "
             "This will be fixed in a future release.")
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}.")
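As a usage note, the new helper is silent for supported configurations and raises otherwise; a quick way to exercise it on a machine with a visible CUDA device (the T4/A100 examples are illustrative):

    import torch

    # float16 always passes: the helper only gates bfloat16.
    _check_if_gpu_supports_dtype(torch.float16)

    # On compute capability >= 8.0 (e.g. an A100) this returns silently;
    # on an older GPU such as a Tesla T4 (7.5) it raises, e.g.:
    #   ValueError: Bfloat16 is only supported on GPUs with compute capability
    #   of at least 8.0. Your Tesla T4 GPU has compute capability 7.5.
    _check_if_gpu_supports_dtype(torch.bfloat16)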