Commit 57adffa2 authored by zhuwenwen's avatar zhuwenwen
Browse files

update qwen2 and mixtral layout

parent 184b50f7
...@@ -375,15 +375,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -375,15 +375,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
......
...@@ -325,6 +325,17 @@ class Qwen2Model(nn.Module): ...@@ -325,6 +325,17 @@ class Qwen2Model(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '0'))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
......
...@@ -22,10 +22,10 @@ except ImportError as e: ...@@ -22,10 +22,10 @@ except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
# import custom ops, trigger op registration # import custom ops, trigger op registration
try: # try:
import vllm._rocm_C # noqa: F401 # import vllm._rocm_C # noqa: F401
except ImportError as e: # except ImportError as e:
logger.warning("Failed to import from vllm._rocm_C with %r", e) # logger.warning("Failed to import from vllm._rocm_C with %r", e)
if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
# logger.warning("`fork` method is not supported by ROCm. " # logger.warning("`fork` method is not supported by ROCm. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment