"tests/vscode:/vscode.git/clone" did not exist on "29aae007fffa965c3bcbae7692f1e3d4f301cff8"
Unverified Commit 11f3cca6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Fix select (#64)

parent ca13f3b8
......@@ -63,7 +63,7 @@ class LogitsProcessor(nn.Module):
# NOTE(review): fragment of a scraped commit diff — the "......@@" line below
# marks elided source (original lines ~66-72, which presumably define `cumsum`
# as a running sum over all_logprobs — TODO confirm against the full file).
# Do not treat this span as a complete, runnable function.
#
# Purpose (as far as visible): compute a length-normalized log probability per
# sequence from a flat tensor of per-token logprobs, where len_add_1 holds each
# sequence's length plus one.
def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
# assert all_logprobs.shape[0] == torch.sum(len_add_1) == input_ids.shape[0]
# assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
# Zero buffer sized (total tokens - num sequences); presumably filled in the
# elided lines before being read via logprobs[start] below — verify.
logprobs = torch.zeros(
(all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
)
......@@ -72,6 +72,7 @@ def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
# Per-sequence segment boundaries. CAUTION: sub_(1) mutates len_add_1 in
# place, so from here on it holds the original lengths minus one.
end = torch.cumsum(len_add_1.sub_(1), dim=0)
# start of each segment = end of the previous one (0 for the first).
start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
# Shift end to the last index inside each segment (inclusive bound).
end.sub_(1)
# Added by this commit ("Fix select"): block until pending async CUDA work
# finishes before the fancy indexing below — presumably fixing a race/ordering
# issue; confirm whether it is still required.
torch.cuda.synchronize()
# `cumsum` comes from the elided lines above; cannot be verified here.
sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
# Normalize by (length - 1) — len_add_1 was decremented in place above.
res = sum_logp / len_add_1
return res
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
import logging
import numpy as np
import torch
......@@ -12,6 +13,10 @@ from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.model_loader import _set_default_torch_dtype
from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
# Module-level logger for the model runner (named "model_runner" rather than
# __name__ — callers configuring logging must use this literal name).
logger = logging.getLogger("model_runner")
# for model_mode
# Process-wide list of model-mode flags; consumers are not visible in this
# fragment — presumably read elsewhere in ModelRunner. TODO confirm.
global_model_mode: List[str] = []
......@@ -257,6 +262,8 @@ class ModelRunner:
# NOTE(review): fragment of ModelRunner's weight-loading method from a scraped
# diff; the method header is not visible and "......@@" lines mark elided
# source. Kept byte-identical apart from these comments.
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")
# Added by this commit: bracket weight loading with logger calls.
logger.info("load weight begin.")
# Load weights
linear_method = None
with _set_default_torch_dtype(torch.float16):
......@@ -267,7 +274,7 @@ class ModelRunner:
if hf_quant_config is not None:
# TODO: config quantization awq etc
quant_config = AWQConfig.from_config(hf_quant_config)
# NOTE(review): the next two lines are the removed (-) and added (+) sides of
# the same diff line — this commit replaces print() with logger.info().
print(f"quant_config: {quant_config}")
logger.info(f"quant_config: {quant_config}")
linear_method = quant_config.get_linear_method()
model = model_class(
config=self.model_config.hf_config, linear_method=linear_method
......@@ -280,6 +287,8 @@ class ModelRunner:
)
# Switch to inference mode once weights are in place.
self.model = model.eval()
logger.info("load weight end.")
def profile_max_num_token(self, total_gpu_memory):
# NOTE(review): truncated in this diff view — the call below is cut
# mid-statement (closing paren and the rest of the method are elided).
# Visible intent: query free GPU memory for this tensor-parallel rank,
# treating the run as distributed when tp_size > 1.
available_gpu_memory = get_available_gpu_memory(
self.tp_rank, distributed=self.tp_size > 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.