Unverified Commit 6409fe6c authored by Hailey Schoelkopf, committed by GitHub

data_parallel -> data_parallel_size

parent de2a60e3
@@ -49,7 +49,7 @@ class VLLM(LM):
         seed: int = 1234,
         gpu_memory_utilization: float = 0.9,
         device: str = "cuda",
-        data_parallel: int = 1,
+        data_parallel_size: int = 1,
     ):
         super().__init__()
@@ -63,7 +63,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
         assert "cuda" in device or device is None, "vLLM only supports CUDA"
         self.tensor_parallel_size = int(tensor_parallel_size)
-        self.data_parallel = int(data_parallel)
+        self.data_parallel_size = int(data_parallel_size)
         self.model_args = {
             "model": pretrained,
             "gpu_memory_utilization": float(gpu_memory_utilization),
@@ -78,7 +78,7 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
             "quantization": quantization,
             "seed": int(seed),
         }
-        if self.data_parallel <= 1:
+        if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
         else:
             self.model_args["worker_use_ray"] = True
@@ -149,11 +149,11 @@ please install vllm via `pip install lm-eval[vllm]` or `pip install -e .[vllm]`"
         sampling_params = SamplingParams(
             temperature=0, prompt_logprobs=2, max_tokens=1
         )
-        if self.data_parallel > 1:
-            requests = [list(x) for x in utils.divide(requests, self.data_parallel)]
+        if self.data_parallel_size > 1:
+            requests = [list(x) for x in utils.divide(requests, self.data_parallel_size)]
             inputs = [(self.model_args, sampling_params, req) for req in requests]
-            with Pool(self.data_parallel) as pool:
+            with Pool(self.data_parallel_size) as pool:
                 results = pool.starmap(run_inference_one_model, inputs)
             # flatten results
             return [item for sublist in results for item in sublist]
...
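For context, the dispatch pattern this rename touches can be illustrated with a minimal, self-contained sketch: requests are split into `data_parallel_size` shards, each shard is handed to a worker process via `multiprocessing.Pool.starmap`, and the per-worker result lists are flattened back into one list. The `split_into_shards` helper and the dummy `run_inference_one_model` worker below are illustrative stand-ins, not the harness's implementation (lm-eval uses its own `utils.divide` and each worker builds a vLLM engine from `model_args`).

```python
# Hypothetical, stripped-down sketch of the data-parallel dispatch touched by
# this commit. run_inference_one_model here is a dummy stand-in: in lm-eval it
# constructs an LLM from model_args and runs generation on its shard of requests.
from multiprocessing import Pool


def split_into_shards(items, n):
    # Evenly split items into n contiguous shards (earlier shards absorb the remainder).
    size, rem = divmod(len(items), n)
    shards, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < rem else 0)
        shards.append(items[start:end])
        start = end
    return shards


def run_inference_one_model(model_args, sampling_params, requests):
    # Dummy worker: tag each request with the model name instead of calling vLLM.
    return [f"{model_args['model']}::{req}" for req in requests]


if __name__ == "__main__":
    data_parallel_size = 2
    model_args = {"model": "dummy-model"}
    sampling_params = {"temperature": 0, "prompt_logprobs": 2, "max_tokens": 1}
    requests = list(range(10))

    if data_parallel_size <= 1:
        results = run_inference_one_model(model_args, sampling_params, requests)
    else:
        # One shard of requests per worker process.
        shards = split_into_shards(requests, data_parallel_size)
        inputs = [(model_args, sampling_params, shard) for shard in shards]
        with Pool(data_parallel_size) as pool:
            nested = pool.starmap(run_inference_one_model, inputs)
        # Flatten the per-worker result lists back into a single list.
        results = [item for sublist in nested for item in sublist]

    print(results)
```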