Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
ee92b58b
"template/testdata/vscode:/vscode.git/clone" did not exist on "93a8daf285af45ed71544e79aae0cb15245e75f4"
Unverified
Commit
ee92b58b
authored
Oct 07, 2023
by
Antoni Baum
Committed by
GitHub
Oct 07, 2023
Browse files
Move bfloat16 check to worker (#1259)
parent
09ff7f10
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
9 deletions
+14
-9
vllm/config.py
vllm/config.py
+0
-9
vllm/worker/worker.py
vllm/worker/worker.py
+14
-0
No files found.
vllm/config.py
View file @
ee92b58b
...
...
@@ -345,15 +345,6 @@ def _get_and_verify_dtype(
# Casting between float16 and bfloat16 is allowed with a warning.
logger
.
warning
(
f
"Casting
{
config_dtype
}
to
{
torch_dtype
}
."
)
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
compute_capability
=
torch
.
cuda
.
get_device_capability
()
if
compute_capability
[
0
]
<
8
:
gpu_name
=
torch
.
cuda
.
get_device_name
()
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU has compute capability "
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
."
)
return
torch_dtype
...
...
vllm/worker/worker.py
View file @
ee92b58b
...
...
@@ -59,6 +59,8 @@ class Worker:
raise
ValueError
(
"Invalid or unspecified rank."
)
torch
.
cuda
.
set_device
(
self
.
device
)
_check_if_gpu_supports_dtype
(
self
.
model_config
.
dtype
)
# Initialize the distributed environment.
_init_distributed_environment
(
self
.
parallel_config
,
self
.
rank
,
self
.
distributed_init_method
)
...
...
@@ -385,3 +387,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
f
"(required shared memory
{
required_shared_mem
}
> "
f
"available shared memory
{
max_shared_mem
}
). "
"This will be fixed in a future release."
)
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
if
torch_dtype
==
torch
.
bfloat16
:
compute_capability
=
torch
.
cuda
.
get_device_capability
()
if
compute_capability
[
0
]
<
8
:
gpu_name
=
torch
.
cuda
.
get_device_name
()
raise
ValueError
(
"Bfloat16 is only supported on GPUs with compute capability "
f
"of at least 8.0. Your
{
gpu_name
}
GPU has compute capability "
f
"
{
compute_capability
[
0
]
}
.
{
compute_capability
[
1
]
}
."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment