Unverified Commit 11f3cca6 authored by Lianmin Zheng, committed by GitHub

Fix select (#64)

parent ca13f3b8
@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
 def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
+    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
     logprobs = torch.zeros(
         (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
     )
@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
     end = torch.cumsum(len_add_1.sub_(1), dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
     end.sub_(1)
+    torch.cuda.synchronize()
     sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
     res = sum_logp / len_add_1
     return res
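For context on the hunk above: compute_normalized_logprobs averages each request's token logprobs out of one flat concatenated tensor, using a prefix-sum (cumsum) gather instead of a Python loop. Below is a minimal CPU sketch of that trick; the function name, the seq_lens argument, and the two-sequence example are illustrative, not the exact sglang code.

import torch

def normalized_logprobs_sketch(all_logprobs, seq_lens):
    # all_logprobs: 1-D tensor of per-token logprobs for all sequences,
    # concatenated back to back. seq_lens: 1-D tensor of sequence lengths.
    cumsum = torch.cumsum(all_logprobs, dim=0)
    end = torch.cumsum(seq_lens, dim=0)  # exclusive end offset of each segment
    start = torch.cat((torch.zeros(1, dtype=end.dtype), end[:-1]))
    # Segment sum via prefix sums: cumsum[end - 1] - cumsum[start] plus the
    # first element of each segment, which sidesteps the start - 1 == -1 case.
    sum_logp = cumsum[end - 1] - cumsum[start] + all_logprobs[start]
    return sum_logp / seq_lens

probs = torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0])
print(normalized_logprobs_sketch(probs, torch.tensor([2, 3])))
# tensor([-1.5000, -4.0000])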
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
+import logging
 import numpy as np
 import torch
@@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
+logger = logging.getLogger("model_runner")
 # for model_mode
 global_model_mode: List[str] = []
@@ -257,6 +262,8 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
+        logger.info("load weight begin.")
         # Load weights
         linear_method = None
         with _set_default_torch_dtype(torch.float16):
@@ -267,7 +274,7 @@ class ModelRunner:
             if hf_quant_config is not None:
                 # TODO: config quantization awq etc
                 quant_config = AWQConfig.from_config(hf_quant_config)
-                print(f"quant_config: {quant_config}")
+                logger.info(f"quant_config: {quant_config}")
                 linear_method = quant_config.get_linear_method()
             model = model_class(
                 config=self.model_config.hf_config, linear_method=linear_method
@@ -280,6 +287,8 @@ class ModelRunner:
             )
         self.model = model.eval()
+        logger.info("load weight end.")

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
...
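The model_runner hunks above swap print for the standard logging module with a named "model_runner" logger. By default Python only emits WARNING and above, so the new INFO messages stay silent unless the embedding application configures logging; a generic setup (not part of this commit) looks like:

import logging

# Configure logging once at startup so that the "model_runner" logger's
# INFO messages ("load weight begin." / "load weight end.") are printed.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

logging.getLogger("model_runner").info("load weight begin.")  # now visible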