Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57adffa2
Commit
57adffa2
authored
Jan 03, 2025
by
zhuwenwen
Browse files
update qwen2 and mixtral layout
parent
184b50f7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
8 deletions
+20
-8
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+5
-4
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+11
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+4
-4
No files found.
vllm/model_executor/models/mixtral.py
View file @
57adffa2
...
@@ -375,15 +375,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -375,15 +375,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
self
.
model
.
make_empty_intermediate_tensors
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
self
.
quant_method
=
None
self
.
quant_method
=
None
if
quant_config
is
not
None
:
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_method
=
quant_config
.
get_name
()
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
model
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/qwen2.py
View file @
57adffa2
...
@@ -325,6 +325,17 @@ class Qwen2Model(nn.Module):
...
@@ -325,6 +325,17 @@ class Qwen2Model(nn.Module):
else
:
else
:
self
.
norm
=
PPMissingLayer
()
self
.
norm
=
PPMissingLayer
()
self
.
quant_method
=
None
if
quant_config
is
not
None
:
self
.
quant_method
=
quant_config
.
get_name
()
self
.
quant_config
=
quant_config
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
...
vllm/platforms/rocm.py
View file @
57adffa2
...
@@ -22,10 +22,10 @@ except ImportError as e:
...
@@ -22,10 +22,10 @@ except ImportError as e:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
# import custom ops, trigger op registration
# import custom ops, trigger op registration
try
:
#
try:
import
vllm._rocm_C
# noqa: F401
#
import vllm._rocm_C # noqa: F401
except
ImportError
as
e
:
#
except ImportError as e:
logger
.
warning
(
"Failed to import from vllm._rocm_C with %r"
,
e
)
#
logger.warning("Failed to import from vllm._rocm_C with %r", e)
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
# logger.warning("`fork` method is not supported by ROCm. "
# logger.warning("`fork` method is not supported by ROCm. "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment