Unverified Commit 27bebcd8 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Convert `examples` to `ruff-format` (#18400)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent e7523c2e
...@@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest ...@@ -17,50 +17,55 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
return [ return [
# this is an example of using quantization without LoRA # this is an example of using quantization without LoRA
("My name is", (
SamplingParams(temperature=0.0, "My name is",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), None), ),
None,
),
# the next three examples use quantization with LoRA # the next three examples use quantization with LoRA
("my name is", (
SamplingParams(temperature=0.0, "my name is",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), ),
LoRARequest("lora-test-1", 1, lora_path)), LoRARequest("lora-test-1", 1, lora_path),
("The capital of USA is", ),
SamplingParams(temperature=0.0, (
logprobs=1, "The capital of USA is",
prompt_logprobs=1, SamplingParams(
max_tokens=128), temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
LoRARequest("lora-test-2", 1, lora_path)), ),
("The capital of France is", LoRARequest("lora-test-2", 1, lora_path),
SamplingParams(temperature=0.0, ),
logprobs=1, (
prompt_logprobs=1, "The capital of France is",
max_tokens=128), SamplingParams(
LoRARequest("lora-test-3", 1, lora_path)), temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
),
LoRARequest("lora-test-3", 1, lora_path),
),
] ]
def process_requests(engine: LLMEngine, def process_requests(
test_prompts: list[tuple[str, SamplingParams, engine: LLMEngine,
Optional[LoRARequest]]]): test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
while test_prompts or engine.has_unfinished_requests(): while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0) prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id), engine.add_request(
prompt, str(request_id), prompt, sampling_params, lora_request=lora_request
sampling_params, )
lora_request=lora_request)
request_id += 1 request_id += 1
request_outputs: list[RequestOutput] = engine.step() request_outputs: list[RequestOutput] = engine.step()
...@@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine, ...@@ -71,15 +76,18 @@ def process_requests(engine: LLMEngine,
print(f"Output: {request_output.outputs[0].text}") print(f"Output: {request_output.outputs[0].text}")
def initialize_engine(model: str, quantization: str, def initialize_engine(
lora_repo: Optional[str]) -> LLMEngine: model: str, quantization: str, lora_repo: Optional[str]
) -> LLMEngine:
"""Initialize the LLMEngine.""" """Initialize the LLMEngine."""
engine_args = EngineArgs(model=model, engine_args = EngineArgs(
model=model,
quantization=quantization, quantization=quantization,
enable_lora=True, enable_lora=True,
max_lora_rank=64, max_lora_rank=64,
max_loras=4) max_loras=4,
)
return LLMEngine.from_engine_args(engine_args) return LLMEngine.from_engine_args(engine_args)
...@@ -90,32 +98,30 @@ def main(): ...@@ -90,32 +98,30 @@ def main():
# QLoRA (https://arxiv.org/abs/2305.14314) # QLoRA (https://arxiv.org/abs/2305.14314)
{ {
"name": "qlora_inference_example", "name": "qlora_inference_example",
'model': "huggyllama/llama-7b", "model": "huggyllama/llama-7b",
'quantization': "bitsandbytes", "quantization": "bitsandbytes",
'lora_repo': 'timdettmers/qlora-flan-7b' "lora_repo": "timdettmers/qlora-flan-7b",
}, },
{ {
"name": "AWQ_inference_with_lora_example", "name": "AWQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
'quantization': "awq", "quantization": "awq",
'lora_repo': 'jashing/tinyllama-colorist-lora' "lora_repo": "jashing/tinyllama-colorist-lora",
}, },
{ {
"name": "GPTQ_inference_with_lora_example", "name": "GPTQ_inference_with_lora_example",
'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', "model": "TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
'quantization': "gptq", "quantization": "gptq",
'lora_repo': 'jashing/tinyllama-colorist-lora' "lora_repo": "jashing/tinyllama-colorist-lora",
} },
] ]
for test_config in test_configs: for test_config in test_configs:
print( print(f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~")
f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~" engine = initialize_engine(
test_config["model"], test_config["quantization"], test_config["lora_repo"]
) )
engine = initialize_engine(test_config['model'], lora_path = snapshot_download(repo_id=test_config["lora_repo"])
test_config['quantization'],
test_config['lora_repo'])
lora_path = snapshot_download(repo_id=test_config['lora_repo'])
test_prompts = create_test_prompts(lora_path) test_prompts = create_test_prompts(lora_path)
process_requests(engine, test_prompts) process_requests(engine, test_prompts)
...@@ -125,5 +131,5 @@ def main(): ...@@ -125,5 +131,5 @@ def main():
torch.cuda.empty_cache() torch.cuda.empty_cache()
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace): ...@@ -74,19 +74,10 @@ def run_simple_demo(args: argparse.Namespace):
messages = [ messages = [
{ {
"role": "role": "user",
"user",
"content": [ "content": [
{ {"type": "text", "text": prompt},
"type": "text", {"type": "image_url", "image_url": {"url": image_url}},
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
], ],
}, },
] ]
...@@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace): ...@@ -121,25 +112,11 @@ def run_advanced_demo(args: argparse.Namespace):
messages = [ messages = [
{ {
"role": "role": "user",
"user",
"content": [ "content": [
{ {"type": "text", "text": prompt},
"type": "text", {"type": "image_url", "image_url": {"url": url_1}},
"text": prompt {"type": "image_url", "image_url": {"url": url_2}},
},
{
"type": "image_url",
"image_url": {
"url": url_1
}
},
{
"type": "image_url",
"image_url": {
"url": url_2
}
},
], ],
}, },
{ {
...@@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace): ...@@ -153,12 +130,7 @@ def run_advanced_demo(args: argparse.Namespace):
{ {
"role": "user", "role": "user",
"content": [ "content": [
{ {"type": "image_url", "image_url": {"url": url_3}},
"type": "image_url",
"image_url": {
"url": url_3
}
},
], ],
}, },
] ]
...@@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace): ...@@ -171,7 +143,8 @@ def run_advanced_demo(args: argparse.Namespace):
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode.") description="Run a demo in simple or advanced mode."
)
parser.add_argument( parser.add_argument(
"mode", "mode",
...@@ -179,15 +152,18 @@ def parse_args(): ...@@ -179,15 +152,18 @@ def parse_args():
help="Specify the demo mode: 'simple' or 'advanced'", help="Specify the demo mode: 'simple' or 'advanced'",
) )
parser.add_argument('--format', parser.add_argument(
"--format",
choices=["mistral", "hf"], choices=["mistral", "hf"],
default="mistral", default="mistral",
help='Specify the format of the model to load.') help="Specify the format of the model to load.",
)
parser.add_argument( parser.add_argument(
'--disable-mm-preprocessor-cache', "--disable-mm-preprocessor-cache",
action='store_true', action="store_true",
help='If True, disables caching of multi-modal preprocessor/mapper.') help="If True, disables caching of multi-modal preprocessor/mapper.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -13,8 +13,9 @@ import time ...@@ -13,8 +13,9 @@ import time
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
def time_generation(llm: LLM, prompts: list[str], def time_generation(
sampling_params: SamplingParams, title: str): llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
# Generate texts from the prompts. The output is a list of RequestOutput # Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information. # objects that contain the prompt, generated text, and other information.
# Warmup first # Warmup first
...@@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str], ...@@ -25,8 +26,7 @@ def time_generation(llm: LLM, prompts: list[str],
end = time.time() end = time.time()
print("-" * 50) print("-" * 50)
print(title) print(title)
print("time: ", print("time: ", (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
(end - start) / sum(len(o.outputs[0].token_ids) for o in outputs))
# Print the outputs. # Print the outputs.
for output in outputs: for output in outputs:
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
...@@ -38,7 +38,8 @@ def main(): ...@@ -38,7 +38,8 @@ def main():
template = ( template = (
"Below is an instruction that describes a task. Write a response " "Below is an instruction that describes a task. Write a response "
"that appropriately completes the request.\n\n### Instruction:\n{}" "that appropriately completes the request.\n\n### Instruction:\n{}"
"\n\n### Response:\n") "\n\n### Response:\n"
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
......
...@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest ...@@ -15,7 +15,7 @@ from vllm.lora.request import LoRARequest
def create_test_prompts( def create_test_prompts(
lora_path: str lora_path: str,
) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]: ) -> list[tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters. """Create a list of test prompts with their sampling parameters.
...@@ -26,38 +26,49 @@ def create_test_prompts( ...@@ -26,38 +26,49 @@ def create_test_prompts(
first adapter have finished. first adapter have finished.
""" """
return [ return [
("A robot may not injure a human being", (
SamplingParams(temperature=0.0, "A robot may not injure a human being",
logprobs=1, SamplingParams(
prompt_logprobs=1, temperature=0.0, logprobs=1, prompt_logprobs=1, max_tokens=128
max_tokens=128), None), ),
("To be or not to be,", None,
SamplingParams(temperature=0.8, ),
top_k=5, (
presence_penalty=0.2, "To be or not to be,",
max_tokens=128), None), SamplingParams(
temperature=0.8, top_k=5, presence_penalty=0.2, max_tokens=128
),
None,
),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, SamplingParams(
temperature=0.0,
logprobs=1, logprobs=1,
prompt_logprobs=1, prompt_logprobs=1,
max_tokens=128, max_tokens=128,
stop_token_ids=[32003]), stop_token_ids=[32003],
LoRARequest("sql-lora", 1, lora_path)), ),
LoRARequest("sql-lora", 1, lora_path),
),
( (
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0, SamplingParams(
temperature=0.0,
logprobs=1, logprobs=1,
prompt_logprobs=1, prompt_logprobs=1,
max_tokens=128, max_tokens=128,
stop_token_ids=[32003]), stop_token_ids=[32003],
LoRARequest("sql-lora2", 2, lora_path)), ),
LoRARequest("sql-lora2", 2, lora_path),
),
] ]
def process_requests(engine: LLMEngine, def process_requests(
test_prompts: list[tuple[str, SamplingParams, engine: LLMEngine,
Optional[LoRARequest]]]): test_prompts: list[tuple[str, SamplingParams, Optional[LoRARequest]]],
):
"""Continuously process a list of prompts and handle the outputs.""" """Continuously process a list of prompts and handle the outputs."""
request_id = 0 request_id = 0
...@@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine, ...@@ -65,10 +76,9 @@ def process_requests(engine: LLMEngine,
while test_prompts or engine.has_unfinished_requests(): while test_prompts or engine.has_unfinished_requests():
if test_prompts: if test_prompts:
prompt, sampling_params, lora_request = test_prompts.pop(0) prompt, sampling_params, lora_request = test_prompts.pop(0)
engine.add_request(str(request_id), engine.add_request(
prompt, str(request_id), prompt, sampling_params, lora_request=lora_request
sampling_params, )
lora_request=lora_request)
request_id += 1 request_id += 1
request_outputs: list[RequestOutput] = engine.step() request_outputs: list[RequestOutput] = engine.step()
...@@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine: ...@@ -88,12 +98,14 @@ def initialize_engine() -> LLMEngine:
# numbers will cause higher memory usage. If you know that all LoRAs will # numbers will cause higher memory usage. If you know that all LoRAs will
# use the same rank, it is recommended to set this as low as possible. # use the same rank, it is recommended to set this as low as possible.
# max_cpu_loras: controls the size of the CPU LoRA cache. # max_cpu_loras: controls the size of the CPU LoRA cache.
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", engine_args = EngineArgs(
model="meta-llama/Llama-2-7b-hf",
enable_lora=True, enable_lora=True,
max_loras=1, max_loras=1,
max_lora_rank=8, max_lora_rank=8,
max_cpu_loras=2, max_cpu_loras=2,
max_num_seqs=256) max_num_seqs=256,
)
return LLMEngine.from_engine_args(engine_args) return LLMEngine.from_engine_args(engine_args)
...@@ -105,5 +117,5 @@ def main(): ...@@ -105,5 +117,5 @@ def main():
process_requests(engine, test_prompts) process_requests(engine, test_prompts)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -30,7 +30,8 @@ def main(): ...@@ -30,7 +30,8 @@ def main():
# The device argument can be either unspecified for automated detection, # The device argument can be either unspecified for automated detection,
# or explicitly assigned. # or explicitly assigned.
device="neuron", device="neuron",
tensor_parallel_size=2) tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -24,7 +24,7 @@ llm = LLM( ...@@ -24,7 +24,7 @@ llm = LLM(
speculative_config={ speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5, "num_speculative_tokens": 5,
"max_model_len": 2048 "max_model_len": 2048,
}, },
max_num_seqs=4, max_num_seqs=4,
# The max_model_len and block_size arguments are required to be same as # The max_model_len and block_size arguments are required to be same as
...@@ -40,7 +40,7 @@ llm = LLM( ...@@ -40,7 +40,7 @@ llm = LLM(
tensor_parallel_size=32, tensor_parallel_size=32,
override_neuron_config={ override_neuron_config={
"enable_eagle_speculation": True, "enable_eagle_speculation": True,
"enable_fused_speculation": True "enable_fused_speculation": True,
}, },
) )
......
...@@ -5,12 +5,12 @@ import os ...@@ -5,12 +5,12 @@ import os
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# creates XLA hlo graphs for all the context length buckets. # creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets. # creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
# Quantizes neuron model weight to int8 , # Quantizes neuron model weight to int8 ,
# The default config for quantization is int8 dtype. # The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8" os.environ["NEURON_QUANT_DTYPE"] = "s8"
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
...@@ -44,7 +44,8 @@ def main(): ...@@ -44,7 +44,8 @@ def main():
override_neuron_config={ override_neuron_config={
"cast_logits_dtype": "bfloat16", "cast_logits_dtype": "bfloat16",
}, },
tensor_parallel_size=2) tensor_parallel_size=2,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects # Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params) outputs = llm.generate(prompts, sampling_params)
......
...@@ -19,9 +19,9 @@ prompts = [ ...@@ -19,9 +19,9 @@ prompts = [
def config_buckets(): def config_buckets():
"""Configure context length and token gen buckets.""" """Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets. # creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets. # creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
def initialize_model(): def initialize_model():
...@@ -31,7 +31,7 @@ def initialize_model(): ...@@ -31,7 +31,7 @@ def initialize_model():
speculative_config={ speculative_config={
"model": "openlm-research/open_llama_3b", "model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4, "num_speculative_tokens": 4,
"max_model_len": 2048 "max_model_len": 2048,
}, },
max_num_seqs=4, max_num_seqs=4,
max_model_len=2048, max_model_len=2048,
...@@ -60,5 +60,5 @@ def main(): ...@@ -60,5 +60,5 @@ def main():
process_requests(model, sampling_params) process_requests(model, sampling_params)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -16,7 +16,8 @@ prefix = ( ...@@ -16,7 +16,8 @@ prefix = (
"teaching role. They have 5 years of previous teaching experience " "teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience " "as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill " "in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ") "the following paragraph: "
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
...@@ -58,9 +59,11 @@ def main(): ...@@ -58,9 +59,11 @@ def main():
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
# Create an LLM with prefix caching enabled. # Create an LLM with prefix caching enabled.
prefix_cached_llm = LLM(model="facebook/opt-125m", prefix_cached_llm = LLM(
model="facebook/opt-125m",
enable_prefix_caching=True, enable_prefix_caching=True,
gpu_memory_utilization=0.4) gpu_memory_utilization=0.4,
)
# Warmup so that the shared prompt's KV cache is computed. # Warmup so that the shared prompt's KV cache is computed.
prefix_cached_llm.generate(generating_prompts[0], sampling_params) prefix_cached_llm.generate(generating_prompts[0], sampling_params)
...@@ -81,10 +84,12 @@ def main(): ...@@ -81,10 +84,12 @@ def main():
print("-" * 50) print("-" * 50)
# Compare the results and display the speedup # Compare the results and display the speedup
generated_same = all([ generated_same = all(
[
regular_generated_texts[i] == cached_generated_texts[i] regular_generated_texts[i] == cached_generated_texts[i]
for i in range(len(prompts)) for i in range(len(prompts))
]) ]
)
print(f"Generated answers are the same: {generated_same}") print(f"Generated answers are the same: {generated_same}")
......
...@@ -17,6 +17,7 @@ Run the example: ...@@ -17,6 +17,7 @@ Run the example:
python prithvi_geospatial_mae.py python prithvi_geospatial_mae.py
""" # noqa: E501 """ # noqa: E501
import argparse import argparse
import datetime import datetime
import os import os
...@@ -110,77 +111,67 @@ model_config = """{ ...@@ -110,77 +111,67 @@ model_config = """{
# Temporarily creating the "config.json" for the model. # Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF # This is going to disappear once the correct config.json is available on HF
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"), with open(
'w') as config_file: os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
) as config_file:
config_file.write(model_config) config_file.write(model_config)
datamodule_config = { datamodule_config = {
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'], "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
'batch_size': "batch_size": 16,
16, "constant_scale": 0.0001,
'constant_scale': "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
0.0001, "drop_last": True,
'data_root': "no_data_replace": 0.0,
'/dccstor/geofm-finetuning/datasets/sen1floods11', "no_label_replace": -1,
'drop_last': "num_workers": 8,
True, "test_transform": [
'no_data_replace': albumentations.Resize(
0.0, always_apply=False, height=448, interpolation=1, p=1, width=448
'no_label_replace': ),
-1, albumentations.pytorch.ToTensorV2(
'num_workers': transpose_mask=False, always_apply=True, p=1.0
8, ),
'test_transform': [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0)
], ],
} }
class PrithviMAE: class PrithviMAE:
def __init__(self): def __init__(self):
print("Initializing PrithviMAE model") print("Initializing PrithviMAE model")
self.model = LLM(model=os.path.join(os.path.dirname(__file__), self.model = LLM(
"./model"), model=os.path.join(os.path.dirname(__file__), "./model"),
skip_tokenizer_init=True, skip_tokenizer_init=True,
dtype="float32") dtype="float32",
)
def run(self, input_data, location_coords): def run(self, input_data, location_coords):
print("################ Running inference on vLLM ##############") print("################ Running inference on vLLM ##############")
# merge the inputs into one data structure # merge the inputs into one data structure
mm_data = { mm_data = {
"pixel_values": "pixel_values": torch.empty(0) if input_data is None else input_data,
torch.empty(0) if input_data is None else input_data, "location_coords": torch.empty(0)
"location_coords": if location_coords is None
torch.empty(0) if location_coords is None else location_coords else location_coords,
} }
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False) outputs = self.model.encode(prompt, use_tqdm=False)
print( print("################ Inference done (it took seconds) ##############")
"################ Inference done (it took seconds) ##############"
)
return outputs[0].outputs.data return outputs[0].outputs.data
def generate_datamodule(): def generate_datamodule():
datamodule = Sen1Floods11NonGeoDataModule( datamodule = Sen1Floods11NonGeoDataModule(
data_root=datamodule_config['data_root'], data_root=datamodule_config["data_root"],
batch_size=datamodule_config["batch_size"], batch_size=datamodule_config["batch_size"],
num_workers=datamodule_config["num_workers"], num_workers=datamodule_config["num_workers"],
bands=datamodule_config["bands"], bands=datamodule_config["bands"],
drop_last=datamodule_config["drop_last"], drop_last=datamodule_config["drop_last"],
test_transform=datamodule_config["test_transform" test_transform=datamodule_config["test_transform"],
""]) )
return datamodule return datamodule
...@@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels): ...@@ -204,8 +195,7 @@ def process_channel_group(orig_img, channels):
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE)) max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
min_value = OFFSET min_value = OFFSET
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
1)
# No data as zeros # No data as zeros
orig_img[~valid_mask] = 0 orig_img[~valid_mask] = 0
...@@ -300,18 +290,21 @@ def load_example( ...@@ -300,18 +290,21 @@ def load_example(
location_coords.append(coords) location_coords.append(coords)
try: try:
match = re.search(r'(\d{7,8}T\d{6})', file) match = re.search(r"(\d{7,8}T\d{6})", file)
if match: if match:
year = int(match.group(1)[:4]) year = int(match.group(1)[:4])
julian_day = match.group(1).split('T')[0][4:] julian_day = match.group(1).split("T")[0][4:]
if len(julian_day) == 3: if len(julian_day) == 3:
julian_day = int(julian_day) julian_day = int(julian_day)
else: else:
julian_day = datetime.datetime.strptime( julian_day = (
julian_day, '%m%d').timetuple().tm_yday datetime.datetime.strptime(julian_day, "%m%d")
.timetuple()
.tm_yday
)
temporal_coords.append([year, julian_day]) temporal_coords.append([year, julian_day])
except Exception as e: except Exception as e:
print(f'Could not extract timestamp for {file} ({e})') print(f"Could not extract timestamp for {file} ({e})")
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32") imgs = np.moveaxis(imgs, -1, 0).astype("float32")
...@@ -320,50 +313,44 @@ def load_example( ...@@ -320,50 +313,44 @@ def load_example(
return imgs, temporal_coords, location_coords, metas return imgs, temporal_coords, location_coords, metas
def run_model(input_data, def run_model(
input_data,
temporal_coords, temporal_coords,
location_coords, location_coords,
model, model,
datamodule, datamodule,
img_size, img_size,
lightning_model=None): lightning_model=None,
):
# Reflect pad if not divisible by img_size # Reflect pad if not divisible by img_size
original_h, original_w = input_data.shape[-2:] original_h, original_w = input_data.shape[-2:]
pad_h = (img_size - (original_h % img_size)) % img_size pad_h = (img_size - (original_h % img_size)) % img_size
pad_w = (img_size - (original_w % img_size)) % img_size pad_w = (img_size - (original_w % img_size)) % img_size
input_data = np.pad(input_data, input_data = np.pad(
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
mode="reflect") )
# Build sliding window # Build sliding window
batch_size = 1 batch_size = 1
batch = torch.tensor(input_data, device="cpu") batch = torch.tensor(input_data, device="cpu")
windows = (batch.unfold(3, img_size, windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
img_size).unfold(4, img_size, img_size))
h1, w1 = windows.shape[3:5] h1, w1 = windows.shape[3:5]
windows = rearrange(windows, windows = rearrange(
"b c t h1 w1 h w -> (b h1 w1) c t h w", windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
h=img_size, )
w=img_size)
# Split into batches if number of windows > batch_size # Split into batches if number of windows > batch_size
num_batches = windows.shape[0] // batch_size if windows.shape[ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
0] > batch_size else 1
windows = torch.tensor_split(windows, num_batches, dim=0) windows = torch.tensor_split(windows, num_batches, dim=0)
if torch.cuda.is_available(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device('cuda')
else:
device = torch.device('cpu')
if temporal_coords: if temporal_coords:
temporal_coords = torch.tensor(temporal_coords, temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
device=device).unsqueeze(0)
else: else:
temporal_coords = None temporal_coords = None
if location_coords: if location_coords:
location_coords = torch.tensor(location_coords[0], location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
device=device).unsqueeze(0)
else: else:
location_coords = None location_coords = None
...@@ -371,26 +358,24 @@ def run_model(input_data, ...@@ -371,26 +358,24 @@ def run_model(input_data,
pred_imgs = [] pred_imgs = []
for x in windows: for x in windows:
# Apply standardization # Apply standardization
x = datamodule.test_transform( x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
image=x.squeeze().numpy().transpose(1, 2, 0)) x = datamodule.aug(x)["image"]
x = datamodule.aug(x)['image']
with torch.no_grad(): with torch.no_grad():
x = x.to(device) x = x.to(device)
pred = model.run(x, location_coords=location_coords) pred = model.run(x, location_coords=location_coords)
if lightning_model: if lightning_model:
pred_lightning = lightning_model( pred_lightning = lightning_model(
x, x, temporal_coords=temporal_coords, location_coords=location_coords
temporal_coords=temporal_coords, )
location_coords=location_coords)
pred_lightning = pred_lightning.output.detach().cpu() pred_lightning = pred_lightning.output.detach().cpu()
if not torch.equal(pred, pred_lightning): if not torch.equal(pred, pred_lightning):
print("Inference output is not equal") print("Inference output is not equal")
y_hat = pred.argmax(dim=1) y_hat = pred.argmax(dim=1)
y_hat = torch.nn.functional.interpolate(y_hat.unsqueeze(1).float(), y_hat = torch.nn.functional.interpolate(
size=img_size, y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
mode="nearest") )
pred_imgs.append(y_hat) pred_imgs.append(y_hat)
...@@ -437,8 +422,7 @@ def parse_args(): ...@@ -437,8 +422,7 @@ def parse_args():
default=[1, 2, 3, 8, 11, 12], default=[1, 2, 3, 8, 11, 12],
type=int, type=int,
nargs="+", nargs="+",
help= help="0-based indices of the six Prithvi channels to be selected from the "
"0-based indices of the six Prithvi channels to be selected from the "
"input. By default selects [1,2,3,8,11,12] for S2L1C data.", "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
) )
parser.add_argument( parser.add_argument(
...@@ -478,17 +462,18 @@ def main( ...@@ -478,17 +462,18 @@ def main(
# Running model ------------------------------------------------------------ # Running model ------------------------------------------------------------
channels = [ channels = [
datamodule_config['bands'].index(b) for b in ["RED", "GREEN", "BLUE"] datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
] # BGR -> RGB ] # BGR -> RGB
pred = run_model(input_data, temporal_coords, location_coords, model_obj, pred = run_model(
datamodule, img_size) input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
)
# Save pred # Save pred
meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
pred_file = os.path.join( pred_file = os.path.join(
output_dir, output_dir, f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
f"pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") )
save_geotiff(_convert_np_uint8(pred), pred_file, meta_data) save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)
# Save image + pred # Save image + pred
...@@ -502,13 +487,13 @@ def main( ...@@ -502,13 +487,13 @@ def main(
channels=channels, channels=channels,
) )
pred[pred == 0.] = np.nan pred[pred == 0.0] = np.nan
img_pred = rgb_orig * 0.7 + pred * 0.3 img_pred = rgb_orig * 0.7 + pred * 0.3
img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()] img_pred[img_pred.isnan()] = rgb_orig[img_pred.isnan()]
img_pred_file = os.path.join( img_pred_file = os.path.join(
output_dir, output_dir, f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff"
f"rgb_pred_{os.path.splitext(os.path.basename(data_file))[0]}.tiff") )
save_geotiff( save_geotiff(
image=_convert_np_uint8(img_pred), image=_convert_np_uint8(img_pred),
output_path=img_pred_file, output_path=img_pred_file,
...@@ -518,8 +503,9 @@ def main( ...@@ -518,8 +503,9 @@ def main(
# Save image rgb # Save image rgb
if rgb_outputs: if rgb_outputs:
rgb_file = os.path.join( rgb_file = os.path.join(
output_dir, "original_rgb_" output_dir,
f"{os.path.splitext(os.path.basename(data_file))[0]}.tiff") f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
)
save_geotiff( save_geotiff(
image=_convert_np_uint8(rgb_orig), image=_convert_np_uint8(rgb_orig),
output_path=rgb_file, output_path=rgb_file,
...@@ -528,7 +514,6 @@ def main( ...@@ -528,7 +514,6 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
main(**vars(args)) main(**vars(args))
...@@ -44,8 +44,11 @@ def get_dtype(dtype: str): ...@@ -44,8 +44,11 @@ def get_dtype(dtype: str):
OutputLen_NumReqs_Map: TypeAlias = dict[int, int] OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
-> OutputLen_NumReqs_Map:
def compute_request_output_lengths(
batch_size: int, step_requests: list[int]
) -> OutputLen_NumReqs_Map:
""" """
Given the number of requests, batch_size, and the number of requests Given the number of requests, batch_size, and the number of requests
that each engine-step should process, step_requests, determine the that each engine-step should process, step_requests, determine the
...@@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \ ...@@ -100,17 +103,19 @@ def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
output_length -= 1 output_length -= 1
# sanity checks. # sanity checks.
assert sum(ol_nr.values()) == batch_size, \ assert sum(ol_nr.values()) == batch_size, (
("Number of requests in output-length assignment does not match " "Number of requests in output-length assignment does not match "
f"batch-size.\n batch size {batch_size} - " f"batch-size.\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}") f"step requests {step_requests} - assignments {ol_nr}"
)
# Check that the output-length is in [1, num-steps]. Output length must be # Check that the output-length is in [1, num-steps]. Output length must be
# at least 1 as all requests must participate in the prefill-step. # at least 1 as all requests must participate in the prefill-step.
assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \ assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), (
("Output lengths of requests should be in range " "Output lengths of requests should be in range "
f"[1, num-engine-steps].\n batch size {batch_size} - " f"[1, num-engine-steps].\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}") f"step requests {step_requests} - assignments {ol_nr}"
)
return ol_nr return ol_nr
...@@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]: ...@@ -140,10 +145,13 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
# that their output lengths must be equal to num_engine_steps. # that their output lengths must be equal to num_engine_steps.
return [context.batch_size] * context.num_steps return [context.batch_size] * context.num_steps
assert context.complete_num_requests_per_step and \ assert (
context.complete_num_requests_per_step > 0, \ context.complete_num_requests_per_step
(f"Expected a positive complete_num_requests_per_step argument." and context.complete_num_requests_per_step > 0
f"Instead got {context.complete_num_requests_per_step}") ), (
f"Expected a positive complete_num_requests_per_step argument."
f"Instead got {context.complete_num_requests_per_step}"
)
# We start dropping after the first decode step. # We start dropping after the first decode step.
step_requests = [ step_requests = [
...@@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]: ...@@ -165,8 +173,9 @@ def determine_requests_per_step(context: ProfileContext) -> list[int]:
return step_requests return step_requests
def run_profile(context: ProfileContext, csv_output: Optional[str], def run_profile(
json_output: Optional[str]): context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]
):
print("Run profile with:") print("Run profile with:")
for key, value in asdict(context).items(): for key, value in asdict(context).items():
print(f" {key} = {value}") print(f" {key} = {value}")
...@@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -174,7 +183,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
requests_per_step: list[int] = determine_requests_per_step(context) requests_per_step: list[int] = determine_requests_per_step(context)
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
context.batch_size, requests_per_step) context.batch_size, requests_per_step
)
num_steps_to_profile: int = len(requests_per_step) num_steps_to_profile: int = len(requests_per_step)
max_output_len: int = max(ol_nr.keys()) max_output_len: int = max(ol_nr.keys())
...@@ -186,7 +196,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -186,7 +196,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
top_p=0.95, top_p=0.95,
# max_tokens is set on a per-request basis. # max_tokens is set on a per-request basis.
max_tokens=None, max_tokens=None,
ignore_eos=True) ignore_eos=True,
)
# Create LLM # Create LLM
llm = LLM(**asdict(context.engine_args)) llm = LLM(**asdict(context.engine_args))
...@@ -199,31 +210,37 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -199,31 +210,37 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
max_num_seqs = scheduler_config.max_num_seqs max_num_seqs = scheduler_config.max_num_seqs
if batch_size * prompt_len > max_num_batched_tokens: if batch_size * prompt_len > max_num_batched_tokens:
print(f"ERROR: chosen batch_size * prompt_len " print(
f"ERROR: chosen batch_size * prompt_len "
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
f"and therefore cannot be run in a single profile step, please " f"and therefore cannot be run in a single profile step, please "
f"choose a smaller batch size or prompt length, or increase " f"choose a smaller batch size or prompt length, or increase "
f"--max-num-batched-tokens") f"--max-num-batched-tokens"
)
sys.exit(-1) sys.exit(-1)
if batch_size > max_num_seqs: if batch_size > max_num_seqs:
print( print(
f"ERROR: chosen batch_size ({batch_size}) is larger than " f"ERROR: chosen batch_size ({batch_size}) is larger than "
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
f"single profile step, please choose a smaller batch size") f"single profile step, please choose a smaller batch size"
)
sys.exit(-1) sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ", print(
llm.llm_engine.model_config.max_model_len) "llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len,
)
if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " print(
f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
f"{max_output_len} = {prompt_len + max_output_len}) is larger " f"{max_output_len} = {prompt_len + max_output_len}) is larger "
f"than the model's max_model_len ({max_model_len}), please " f"than the model's max_model_len ({max_model_len}), please "
f"choose a smaller prompt_len or max_output_len, or increase " f"choose a smaller prompt_len or max_output_len, or increase "
f"--max-model-len") f"--max-model-len"
)
sys.exit(-1) sys.exit(-1)
def add_requests(): def add_requests():
def get_output_len_generator() -> Generator[int, Any, Any]: def get_output_len_generator() -> Generator[int, Any, Any]:
for output_len, num_reqs in ol_nr.items(): for output_len, num_reqs in ol_nr.items():
for _ in range(num_reqs): for _ in range(num_reqs):
...@@ -234,13 +251,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -234,13 +251,15 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
sampling_params.max_tokens = next(output_len_generator) sampling_params.max_tokens = next(output_len_generator)
assert isinstance(sampling_params.max_tokens, int) assert isinstance(sampling_params.max_tokens, int)
prompt_token_ids = torch.randint(llm.get_tokenizer().vocab_size, prompt_token_ids = torch.randint(
size=(prompt_len, )).tolist() llm.get_tokenizer().vocab_size, size=(prompt_len,)
).tolist()
llm.llm_engine.add_request( llm.llm_engine.add_request(
request_id=f"seq{i}", request_id=f"seq{i}",
prompt={'prompt_token_ids': prompt_token_ids}, prompt={"prompt_token_ids": prompt_token_ids},
params=sampling_params) params=sampling_params,
)
def abort_requests(): def abort_requests():
for i in range(batch_size): for i in range(batch_size):
...@@ -261,10 +280,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -261,10 +280,8 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
decode_profs = [] decode_profs = []
for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
num_running_seqs = llm.llm_engine.scheduler[ num_running_seqs = llm.llm_engine.scheduler[0].get_num_unfinished_seq_groups()
0].get_num_unfinished_seq_groups() with layerwise_profile(num_running_seqs=num_running_seqs) as decode_prof:
with layerwise_profile(
num_running_seqs=num_running_seqs) as decode_prof:
llm.llm_engine.step() llm.llm_engine.step()
decode_profs.append(decode_prof) decode_profs.append(decode_prof)
...@@ -274,8 +291,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -274,8 +291,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
LINE_WIDTH = 80 LINE_WIDTH = 80
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= Prefill Model Table " print(f"= Prefill Model Table (prompt_len={prompt_len}, batch_size={batch_size})")
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
prefill_results.print_model_table() prefill_results.print_model_table()
...@@ -283,16 +299,17 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -283,16 +299,17 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if has_decode: if has_decode:
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= First Decode Step Model Table " print(
f"(prompt_len={prompt_len}, batch_size={batch_size})") f"= First Decode Step Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})"
)
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
decode_results_list[0].print_model_table() decode_results_list[0].print_model_table()
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= Prefill Summary Table " print(f"= Prefill Summary Table (prompt_len={prompt_len}, batch_size={batch_size})")
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
prefill_results.print_summary_table() prefill_results.print_summary_table()
...@@ -300,25 +317,32 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -300,25 +317,32 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if has_decode: if has_decode:
print() print()
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print(f"= First Decode Step Summary Table " print(
f"(prompt_len={prompt_len}, batch_size={batch_size})") f"= First Decode Step Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})"
)
print("=" * LINE_WIDTH) print("=" * LINE_WIDTH)
print() print()
decode_results_list[0].print_summary_table() decode_results_list[0].print_summary_table()
if csv_output: if csv_output:
csv_filename_base = csv_output[:-4] \ csv_filename_base = (
if csv_output.endswith('.csv') else csv_output csv_output[:-4] if csv_output.endswith(".csv") else csv_output
)
prefill_results.export_model_stats_table_csv( prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv") csv_filename_base + "_prefill_model_table.csv"
)
prefill_results.export_summary_stats_table_csv( prefill_results.export_summary_stats_table_csv(
csv_filename_base + "_prefill_summary_table.csv") csv_filename_base + "_prefill_summary_table.csv"
)
if has_decode: if has_decode:
decode_results_list[0].export_model_stats_table_csv(\ decode_results_list[0].export_model_stats_table_csv(
csv_filename_base + "_decode_model_table.csv") csv_filename_base + "_decode_model_table.csv"
)
decode_results_list[0].export_summary_stats_table_csv( decode_results_list[0].export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv") csv_filename_base + "_decode_summary_table.csv"
)
if json_output: if json_output:
cuda_devices = [ cuda_devices = [
...@@ -332,7 +356,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -332,7 +356,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
"torch_version": f"{torch.__version__}", "torch_version": f"{torch.__version__}",
"torch_cuda_version": f"{torch.version.cuda}", "torch_cuda_version": f"{torch.version.cuda}",
"cuda_devices": f"{cuda_devices}", "cuda_devices": f"{cuda_devices}",
**asdict(context) **asdict(context),
}, },
"prefill": prefill_results.convert_stats_to_dict(), "prefill": prefill_results.convert_stats_to_dict(),
} }
...@@ -342,8 +366,9 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -342,8 +366,9 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()
# Add .json to json_output filename if it doesn't exist already. # Add .json to json_output filename if it doesn't exist already.
json_output_file = json_output if json_output.endswith( json_output_file = (
'.json') else json_output + '.json' json_output if json_output.endswith(".json") else json_output + ".json"
)
with open(json_output_file, "w+") as f: with open(json_output_file, "w+") as f:
json.dump(json_dict, f, indent=2) json.dump(json_dict, f, indent=2)
pass pass
...@@ -351,16 +376,21 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], ...@@ -351,16 +376,21 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
if context.save_chrome_traces_folder is not None: if context.save_chrome_traces_folder is not None:
os.makedirs(context.save_chrome_traces_folder, exist_ok=True) os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
prefill_prof.profiler.export_chrome_trace( prefill_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + "/prefill.json") context.save_chrome_traces_folder + "/prefill.json"
)
for idx, decode_prof in enumerate(decode_profs): for idx, decode_prof in enumerate(decode_profs):
decode_prof.profiler.export_chrome_trace( decode_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json") context.save_chrome_traces_folder + f"/decode_{idx + 1}.json"
print("Traces saved as prefill.json and decode_1.json, etc." )
f" in folder {context.save_chrome_traces_folder}") print(
"Traces saved as prefill.json and decode_1.json, etc."
f" in folder {context.save_chrome_traces_folder}"
)
def parse_args(): def parse_args():
parser = FlexibleArgumentParser(description=""" parser = FlexibleArgumentParser(
description="""
Profile a model Profile a model
example: example:
...@@ -384,7 +414,8 @@ Profile a model ...@@ -384,7 +414,8 @@ Profile a model
--output-directory profile_breakdown --plot-metric pct_cuda_time --output-directory profile_breakdown --plot-metric pct_cuda_time
``` ```
""", """,
formatter_class=RawTextHelpFormatter) formatter_class=RawTextHelpFormatter,
)
parser.add_argument( parser.add_argument(
"--csv", "--csv",
type=str, type=str,
...@@ -393,59 +424,68 @@ Profile a model ...@@ -393,59 +424,68 @@ Profile a model
"filename, will create <filename>_prefill_model_table.csv, " "filename, will create <filename>_prefill_model_table.csv, "
"<filename>_prefill_summary_table.csv, " "<filename>_prefill_summary_table.csv, "
"<filename>_decode_model_table.csv, and " "<filename>_decode_model_table.csv, and "
"<filename>_decode_summary_table.csv") "<filename>_decode_summary_table.csv",
)
parser.add_argument( parser.add_argument(
"--json", "--json",
type=str, type=str,
default=None, default=None,
help="Export the results as a json file. This should be the filename") help="Export the results as a json file. This should be the filename",
parser.add_argument("--save-chrome-traces-folder", )
parser.add_argument(
"--save-chrome-traces-folder",
type=str, type=str,
help="Save chrome traces for the prefill and decode " help="Save chrome traces for the prefill and decode "
"will save traces as prefill.json and decode_1.json, " "will save traces as prefill.json and decode_1.json, "
"etc. inside this folder") "etc. inside this folder",
)
parser.add_argument( parser.add_argument(
"--prompt-len", "--prompt-len",
type=int, type=int,
default=PROMPT_LEN_DEFAULT, default=PROMPT_LEN_DEFAULT,
help=f"Length of the random prompt to use when profiling, all batched " help=f"Length of the random prompt to use when profiling, all batched "
f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}",
parser.add_argument("--batch-size", )
parser.add_argument(
"--batch-size",
type=int, type=int,
default=BATCH_SIZE_DEFAULT, default=BATCH_SIZE_DEFAULT,
help=f"Number of requests to run as a single batch, " help=f"Number of requests to run as a single batch, "
f"default={BATCH_SIZE_DEFAULT}") f"default={BATCH_SIZE_DEFAULT}",
)
subparsers = parser.add_subparsers(dest="cmd") subparsers = parser.add_subparsers(dest="cmd")
run_num_steps_parser = subparsers.add_parser( run_num_steps_parser = subparsers.add_parser(
"run_num_steps", "run_num_steps", help="This variation profiles n engine.step() invocations."
help="This variation profiles n engine.step() invocations.") )
run_num_steps_parser.add_argument( run_num_steps_parser.add_argument(
'-n', "-n",
'--num-steps', "--num-steps",
type=int, type=int,
help="Number of engine steps to profile.\n" help="Number of engine steps to profile.\n"
"Setting it to 1, profiles only the prefill step.\n" "Setting it to 1, profiles only the prefill step.\n"
"Setting it to 2, profiles the prefill and first decode step\n" "Setting it to 2, profiles the prefill and first decode step\n"
"Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
"and so on ...") "and so on ...",
)
run_to_completion_parser = subparsers.add_parser( run_to_completion_parser = subparsers.add_parser(
"run_to_completion", "run_to_completion",
help="This variation profiles all the engine.step() invocations" help="This variation profiles all the engine.step() invocations"
"until the engine exhausts all submitted requests.") "until the engine exhausts all submitted requests.",
)
run_to_completion_parser.add_argument( run_to_completion_parser.add_argument(
'-n', "-n",
'--complete-num-requests-per-step', "--complete-num-requests-per-step",
type=int, type=int,
help= help="Complete complete_num_requests_per_step requests every decode step."
"Complete complete_num_requests_per_step requests every decode step."
"For e.g., with batch_size 128 and complete_num_requests_per_step 32," "For e.g., with batch_size 128 and complete_num_requests_per_step 32,"
"the profiler is run for 6 engine steps, with the steps processing, " "the profiler is run for 6 engine steps, with the steps processing, "
"128, 128, 96, 64, 32, 1 requests respectively.\n" "128, 128, 96, 64, 32, 1 requests respectively.\n"
"Note that we tack-on a one-request step at the end as it is often " "Note that we tack-on a one-request step at the end as it is often "
"useful.") "useful.",
)
EngineArgs.add_cli_args(parser) EngineArgs.add_cli_args(parser)
...@@ -459,7 +499,8 @@ def main(args): ...@@ -459,7 +499,8 @@ def main(args):
k: v k: v
for k, v in vars(args).items() for k, v in vars(args).items()
if k in inspect.signature(ProfileContext).parameters if k in inspect.signature(ProfileContext).parameters
}) },
)
run_profile(context, csv_output=args.csv, json_output=args.json) run_profile(context, csv_output=args.csv, json_output=args.json)
......
...@@ -31,18 +31,16 @@ def main(args: argparse.Namespace): ...@@ -31,18 +31,16 @@ def main(args: argparse.Namespace):
max_tokens=args.output_len, max_tokens=args.output_len,
) )
print(sampling_params) print(sampling_params)
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(
size=(args.batch_size, 10000, size=(args.batch_size, args.input_len)
args.input_len)) )
dummy_prompts: list[PromptType] = [{ dummy_prompts: list[PromptType] = [
"prompt_token_ids": batch {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist()
} for batch in dummy_prompt_token_ids.tolist()] ]
def run_to_completion(): def run_to_completion():
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(dummy_prompts, llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False)
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
latency = end_time - start_time latency = end_time - start_time
return latency return latency
...@@ -58,10 +56,9 @@ def main(args: argparse.Namespace): ...@@ -58,10 +56,9 @@ def main(args: argparse.Namespace):
profile_dir = args.profile_result_dir profile_dir = args.profile_result_dir
print(f"Profiling (results will be saved to '{profile_dir}')...") print(f"Profiling (results will be saved to '{profile_dir}')...")
# Enable tracing on server # Enable tracing on server
xp.trace_detached("localhost:9012", xp.trace_detached(
profile_dir, "localhost:9012", profile_dir, delay_ms=DELAY_MS, duration_ms=DURATION_MS
delay_ms=DELAY_MS, )
duration_ms=DURATION_MS)
if DELAY_MS == 0: if DELAY_MS == 0:
time.sleep(1.0) time.sleep(1.0)
profile_latencies = [] profile_latencies = []
...@@ -72,30 +69,36 @@ def main(args: argparse.Namespace): ...@@ -72,30 +69,36 @@ def main(args: argparse.Namespace):
return return
if __name__ == '__main__': if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of ' description="Benchmark the latency of processing a single batch of "
'requests till completion.') "requests till completion."
parser.add_argument('--input-len', type=int, default=32) )
parser.add_argument('--output-len', type=int, default=128) parser.add_argument("--input-len", type=int, default=32)
parser.add_argument('--batch-size', type=int, default=8) parser.add_argument("--output-len", type=int, default=128)
parser.add_argument('--num-iters-warmup', parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument(
"--num-iters-warmup",
type=int, type=int,
default=5, default=5,
help='Number of iterations to run for warmup.') help="Number of iterations to run for warmup.",
parser.add_argument('--num-iters', )
parser.add_argument(
"--num-iters",
type=int, type=int,
default=1, default=1,
help='Number of iterations to run for profiling.') help="Number of iterations to run for profiling.",
)
parser.add_argument( parser.add_argument(
'--profile-result-dir', "--profile-result-dir",
type=str, type=str,
default="profiles", default="profiles",
help= help=(
('path to save the pytorch profiler output. Can be visualized ' "path to save the pytorch profiler output. Can be visualized "
'with ui.perfetto.dev or Tensorboard ' "with ui.perfetto.dev or Tensorboard "
'(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm).' "(https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)."
)) ),
)
parser = EngineArgs.add_cli_args(parser) parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
......
...@@ -18,8 +18,7 @@ Run: ...@@ -18,8 +18,7 @@ Run:
""" """
import torch import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer, from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
PreTrainedTokenizer)
from vllm import LLM from vllm import LLM
...@@ -32,27 +31,29 @@ def init_tokenizer_and_llm(model_name: str): ...@@ -32,27 +31,29 @@ def init_tokenizer_and_llm(model_name: str):
return tokenizer, embedding_layer, llm return tokenizer, embedding_layer, llm
def get_prompt_embeds(chat: list[dict[str, def get_prompt_embeds(
str]], tokenizer: PreTrainedTokenizer, chat: list[dict[str, str]],
embedding_layer: torch.nn.Module): tokenizer: PreTrainedTokenizer,
token_ids = tokenizer.apply_chat_template(chat, embedding_layer: torch.nn.Module,
add_generation_prompt=True, ):
return_tensors='pt') token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt"
)
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds return prompt_embeds
def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, def single_prompt_inference(
embedding_layer: torch.nn.Module): llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
chat = [{ ):
"role": "user", chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
"content": "Please tell me about the capital of France."
}]
prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer) prompt_embeds = get_prompt_embeds(chat, tokenizer, embedding_layer)
outputs = llm.generate({ outputs = llm.generate(
{
"prompt_embeds": prompt_embeds, "prompt_embeds": prompt_embeds,
}) }
)
print("\n[Single Inference Output]") print("\n[Single Inference Output]")
print("-" * 30) print("-" * 30)
...@@ -61,34 +62,26 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, ...@@ -61,34 +62,26 @@ def single_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer,
print("-" * 30) print("-" * 30)
def batch_prompt_inference(llm: LLM, tokenizer: PreTrainedTokenizer, def batch_prompt_inference(
embedding_layer: torch.nn.Module): llm: LLM, tokenizer: PreTrainedTokenizer, embedding_layer: torch.nn.Module
chats = [[{ ):
"role": "user", chats = [
"content": "Please tell me about the capital of France." [{"role": "user", "content": "Please tell me about the capital of France."}],
}], [{"role": "user", "content": "When is the day longest during the year?"}],
[{ [{"role": "user", "content": "Where is bigger, the moon or the sun?"}],
"role": "user", ]
"content": "When is the day longest during the year?"
}],
[{
"role": "user",
"content": "Where is bigger, the moon or the sun?"
}]]
prompt_embeds_list = [ prompt_embeds_list = [
get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats get_prompt_embeds(chat, tokenizer, embedding_layer) for chat in chats
] ]
outputs = llm.generate([{ outputs = llm.generate([{"prompt_embeds": embeds} for embeds in prompt_embeds_list])
"prompt_embeds": embeds
} for embeds in prompt_embeds_list])
print("\n[Batch Inference Outputs]") print("\n[Batch Inference Outputs]")
print("-" * 30) print("-" * 30)
for i, o in enumerate(outputs): for i, o in enumerate(outputs):
print(f"Q{i+1}: {chats[i][0]['content']}") print(f"Q{i + 1}: {chats[i][0]['content']}")
print(f"A{i+1}: {o.outputs[0].text}\n") print(f"A{i + 1}: {o.outputs[0].text}\n")
print("-" * 30) print("-" * 30)
......
...@@ -27,51 +27,55 @@ class QueryResult(NamedTuple): ...@@ -27,51 +27,55 @@ class QueryResult(NamedTuple):
default_system = ( default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba " "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as " "Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.") "generating text and speech."
)
def get_mixed_modalities_query() -> QueryResult: def get_mixed_modalities_query() -> QueryResult:
question = ("What is recited in the audio? " question = (
"What is the content of this image? Why is this video funny?") "What is recited in the audio? "
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" "What is the content of this image? Why is this video funny?"
)
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>" "<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>" "<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n" f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n") f"<|im_start|>assistant\n"
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
"multi_modal_data": { "multi_modal_data": {
"audio": "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate, "image": convert_image_mode(
"image": ImageAsset("cherry_blossom").pil_image, "RGB"
convert_image_mode( ),
ImageAsset("cherry_blossom").pil_image, "RGB"), "video": VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
"video":
VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
}, },
limit_mm_per_prompt={
"audio": 1,
"image": 1,
"video": 1
}, },
limit_mm_per_prompt={"audio": 1, "image": 1, "video": 1},
) )
def get_use_audio_in_video_query() -> QueryResult: def get_use_audio_in_video_query() -> QueryResult:
question = ("Describe the content of the video, " question = (
"then convert what the baby say into text.") "Describe the content of the video, then convert what the baby say into text."
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" )
prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n" f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n") f"<|im_start|>assistant\n"
)
asset = VideoAsset(name="baby_reading", num_frames=16) asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000) audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " assert not envs.VLLM_USE_V1, (
"V1 does not support use_audio_in_video. "
"Please launch this example with " "Please launch this example with "
"`VLLM_USE_V1=0`.") "`VLLM_USE_V1=0`."
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
...@@ -83,20 +87,19 @@ def get_use_audio_in_video_query() -> QueryResult: ...@@ -83,20 +87,19 @@ def get_use_audio_in_video_query() -> QueryResult:
"use_audio_in_video": True, "use_audio_in_video": True,
}, },
}, },
limit_mm_per_prompt={ limit_mm_per_prompt={"audio": 1, "video": 1},
"audio": 1,
"video": 1
},
) )
def get_multi_audios_query() -> QueryResult: def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?" question = "Are these two audio clips the same?"
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n" prompt = (
f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>" "<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n" f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n") f"<|im_start|>assistant\n"
)
return QueryResult( return QueryResult(
inputs={ inputs={
"prompt": prompt, "prompt": prompt,
...@@ -124,18 +127,19 @@ def main(args): ...@@ -124,18 +127,19 @@ def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B" model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]() query_result = query_map[args.query_type]()
llm = LLM(model=model_name, llm = LLM(
model=model_name,
max_model_len=5632, max_model_len=5632,
max_num_seqs=5, max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt, limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed) seed=args.seed,
)
# We set temperature to 0.2 so that outputs can be different # We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference. # even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64) sampling_params = SamplingParams(temperature=0.2, max_tokens=64)
outputs = llm.generate(query_result.inputs, outputs = llm.generate(query_result.inputs, sampling_params=sampling_params)
sampling_params=sampling_params)
for o in outputs: for o in outputs:
generated_text = o.outputs[0].text generated_text = o.outputs[0].text
...@@ -144,18 +148,23 @@ def main(args): ...@@ -144,18 +148,23 @@ def main(args):
def parse_args(): def parse_args():
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with ' description="Demo on using vLLM for offline inference with "
'audio language models') "audio language models"
parser.add_argument('--query-type', )
'-q', parser.add_argument(
"--query-type",
"-q",
type=str, type=str,
default="mixed_modalities", default="mixed_modalities",
choices=query_map.keys(), choices=query_map.keys(),
help='Query type.') help="Query type.",
parser.add_argument("--seed", )
parser.add_argument(
"--seed",
type=int, type=int,
default=None, default=None,
help="Set the seed when initializing `vllm.LLM`.") help="Set the seed when initializing `vllm.LLM`.",
)
return parser.parse_args() return parser.parse_args()
......
...@@ -17,10 +17,10 @@ def load_prompt() -> str: ...@@ -17,10 +17,10 @@ def load_prompt() -> str:
# https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt
with urlopen( with urlopen(
"https://qianwen-res.oss-cn-beijing.aliyuncs.com" "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt",
"/Qwen2.5-1M/test-data/600k.txt", timeout=5,
timeout=5) as response: ) as response:
prompt = response.read().decode('utf-8') prompt = response.read().decode("utf-8")
return prompt return prompt
...@@ -41,18 +41,22 @@ def process_requests(llm: LLM, prompts: list[str]) -> None: ...@@ -41,18 +41,22 @@ def process_requests(llm: LLM, prompts: list[str]) -> None:
for output in outputs: for output in outputs:
prompt_token_ids = output.prompt_token_ids prompt_token_ids = output.prompt_token_ids
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt length: {len(prompt_token_ids)}, " print(
f"Generated text: {generated_text!r}") f"Prompt length: {len(prompt_token_ids)}, "
f"Generated text: {generated_text!r}"
)
# Create an LLM. # Create an LLM.
def initialize_engine() -> LLM: def initialize_engine() -> LLM:
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M", llm = LLM(
model="Qwen/Qwen2.5-7B-Instruct-1M",
max_model_len=1048576, max_model_len=1048576,
tensor_parallel_size=4, tensor_parallel_size=4,
enforce_eager=True, enforce_eager=True,
enable_chunked_prefill=True, enable_chunked_prefill=True,
max_num_batched_tokens=131072) max_num_batched_tokens=131072,
)
return llm return llm
...@@ -62,5 +66,5 @@ def main(): ...@@ -62,5 +66,5 @@ def main():
process_requests(llm, [prompt]) process_requests(llm, [prompt])
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -12,6 +12,7 @@ inference instance. In practice, there could be multiple training instances ...@@ -12,6 +12,7 @@ inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework. to the OpenRLHF framework.
""" """
import os import os
import ray import ray
...@@ -26,7 +27,6 @@ from vllm.utils import get_ip, get_open_port ...@@ -26,7 +27,6 @@ from vllm.utils import get_ip, get_open_port
class MyLLM(LLM): class MyLLM(LLM):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# a hack to make the script work. # a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES # stop ray from manipulating CUDA_VISIBLE_DEVICES
...@@ -89,8 +89,7 @@ print("-" * 50) ...@@ -89,8 +89,7 @@ print("-" * 50)
for output in outputs: for output in outputs:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
f"Generated text: {generated_text!r}")
print("-" * 50) print("-" * 50)
# set up the communication between the training process # set up the communication between the training process
...@@ -98,11 +97,13 @@ for output in outputs: ...@@ -98,11 +97,13 @@ for output in outputs:
master_address = get_ip() master_address = get_ip()
master_port = get_open_port() master_port = get_open_port()
handle = llm.collective_rpc.remote("init_weight_update_group", handle = llm.collective_rpc.remote(
args=(master_address, master_port, 1, 3)) "init_weight_update_group", args=(master_address, master_port, 1, 3)
)
model_update_group = stateless_init_process_group(master_address, master_port, model_update_group = stateless_init_process_group(
0, 3, torch.device("cuda:0")) master_address, master_port, 0, 3, torch.device("cuda:0")
)
ray.get(handle) ray.get(handle)
# simulate training, modify the weights of the model. # simulate training, modify the weights of the model.
...@@ -111,8 +112,7 @@ for name, p in train_model.named_parameters(): ...@@ -111,8 +112,7 @@ for name, p in train_model.named_parameters():
# sync weight from the training process to the inference engine. # sync weight from the training process to the inference engine.
for name, p in train_model.named_parameters(): for name, p in train_model.named_parameters():
handle = llm.collective_rpc.remote("update_weight", handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
args=(name, p.dtype, p.shape))
model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
ray.get(handle) ray.get(handle)
...@@ -126,6 +126,5 @@ print("-" * 50) ...@@ -126,6 +126,5 @@ print("-" * 50)
for output in outputs_updated: for output in outputs_updated:
prompt = output.prompt prompt = output.prompt
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\n" print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
f"Generated text: {generated_text!r}")
print("-" * 50) print("-" * 50)
...@@ -9,6 +9,7 @@ The key points: ...@@ -9,6 +9,7 @@ The key points:
- Use cuda-ipc to pass tensors, since NCCL does not work when we have - Use cuda-ipc to pass tensors, since NCCL does not work when we have
multiple processes on the same GPU. multiple processes on the same GPU.
""" """
import os import os
import ray import ray
...@@ -20,7 +21,6 @@ from vllm import LLM ...@@ -20,7 +21,6 @@ from vllm import LLM
class MyLLM(LLM): class MyLLM(LLM):
def __init__(self, *args, bundle_indices: list, **kwargs): def __init__(self, *args, bundle_indices: list, **kwargs):
# a hack to make the script work. # a hack to make the script work.
# stop ray from manipulating CUDA_VISIBLE_DEVICES # stop ray from manipulating CUDA_VISIBLE_DEVICES
...@@ -29,17 +29,16 @@ class MyLLM(LLM): ...@@ -29,17 +29,16 @@ class MyLLM(LLM):
# every worker will use 0.4 GPU, so that we can schedule # every worker will use 0.4 GPU, so that we can schedule
# 2 instances on the same GPUs. # 2 instances on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join( os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}") print(f"creating LLM with bundle_indices={bundle_indices}")
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
class RayTrainingActor: class RayTrainingActor:
def __init__(self): def __init__(self):
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from transformers import AutoModelForCausalLM from transformers import AutoModelForCausalLM
self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0") self.model.to("cuda:0")
for name, p in self.model.named_parameters(): for name, p in self.model.named_parameters():
...@@ -48,6 +47,7 @@ class RayTrainingActor: ...@@ -48,6 +47,7 @@ class RayTrainingActor:
# the argument for get_device_uuid is the index # the argument for get_device_uuid is the index
# of the GPU in the visible devices. # of the GPU in the visible devices.
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(0) self.device_uuid = current_platform.get_device_uuid(0)
def report_device_id(self) -> str: def report_device_id(self) -> str:
...@@ -55,6 +55,7 @@ class RayTrainingActor: ...@@ -55,6 +55,7 @@ class RayTrainingActor:
def get_weight_ipc_handles(self): def get_weight_ipc_handles(self):
from torch.multiprocessing.reductions import reduce_tensor from torch.multiprocessing.reductions import reduce_tensor
data = {} data = {}
for name, p in self.model.named_parameters(): for name, p in self.model.named_parameters():
# the training actor might only have a subset of the weights # the training actor might only have a subset of the weights
...@@ -101,7 +102,7 @@ for bundle_index, training_actor in enumerate(training_actors): ...@@ -101,7 +102,7 @@ for bundle_index, training_actor in enumerate(training_actors):
print(f"training actor {bundle_index} is on {device_id}") print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id) training_actor_device_ids.append(device_id)
for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# IMPORTANT: when creating vLLM instances, we need to # IMPORTANT: when creating vLLM instances, we need to
# make sure there are no GPU activities on the target GPUs, # make sure there are no GPU activities on the target GPUs,
# otherwise, they will interfere with the vLLM memory profiling, # otherwise, they will interfere with the vLLM memory profiling,
...@@ -128,7 +129,8 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]): ...@@ -128,7 +129,8 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
for i, llm in enumerate(inference_engines): for i, llm in enumerate(inference_engines):
inference_engine_device_ids.append( inference_engine_device_ids.append(
ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))) ray.get(llm.collective_rpc.remote("report_device_id", args=tuple()))
)
print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") print(f"inference engine {i} is on {inference_engine_device_ids[-1]}")
# check the placement # check the placement
...@@ -147,9 +149,10 @@ for actor in training_actors: ...@@ -147,9 +149,10 @@ for actor in training_actors:
print("update the weights of the inference engines") print("update the weights of the inference engines")
for llm in inference_engines: for llm in inference_engines:
ray.get( ray.get(
llm.collective_rpc.remote("update_weights_from_ipc_handles", llm.collective_rpc.remote(
args=(ipc_handles, ))) "update_weights_from_ipc_handles", args=(ipc_handles,)
)
)
print("check if the weights are updated") print("check if the weights are updated")
for llm in inference_engines: for llm in inference_engines:
assert ray.get( assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple()))
llm.collective_rpc.remote("check_weights_changed", args=tuple()))
...@@ -2,8 +2,7 @@ ...@@ -2,8 +2,7 @@
import torch import torch
def stateless_init_process_group(master_address, master_port, rank, world_size, def stateless_init_process_group(master_address, master_port, rank, world_size, device):
device):
""" """
vLLM provides `StatelessProcessGroup` to create a process group vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed. without considering the global process group in torch.distributed.
...@@ -13,10 +12,10 @@ def stateless_init_process_group(master_address, master_port, rank, world_size, ...@@ -13,10 +12,10 @@ def stateless_init_process_group(master_address, master_port, rank, world_size,
""" """
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup from vllm.distributed.utils import StatelessProcessGroup
pg = StatelessProcessGroup.create(host=master_address,
port=master_port, pg = StatelessProcessGroup.create(
rank=rank, host=master_address, port=master_port, rank=rank, world_size=world_size
world_size=world_size) )
pynccl = PyNcclCommunicator(pg, device=device) pynccl = PyNcclCommunicator(pg, device=device)
return pynccl return pynccl
...@@ -31,9 +30,11 @@ class WorkerExtension: ...@@ -31,9 +30,11 @@ class WorkerExtension:
should pass the full qualified name as `worker_extension_cls` argument. should pass the full qualified name as `worker_extension_cls` argument.
""" """
def init_weight_update_group(self, master_address, master_port, def init_weight_update_group(
rank_offset, world_size): self, master_address, master_port, rank_offset, world_size
):
from vllm.distributed.parallel_state import get_world_group from vllm.distributed.parallel_state import get_world_group
rank = get_world_group().rank + rank_offset rank = get_world_group().rank + rank_offset
self.model_update_group = stateless_init_process_group( self.model_update_group = stateless_init_process_group(
master_address, master_address,
...@@ -45,9 +46,9 @@ class WorkerExtension: ...@@ -45,9 +46,9 @@ class WorkerExtension:
def update_weight(self, name, dtype, shape): def update_weight(self, name, dtype, shape):
weight = torch.empty(shape, dtype=dtype, device="cuda") weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(weight, self.model_update_group.broadcast(
src=0, weight, src=0, stream=torch.cuda.current_stream()
stream=torch.cuda.current_stream()) )
self.model_runner.model.load_weights(weights=[(name, weight)]) self.model_runner.model.load_weights(weights=[(name, weight)])
...@@ -59,8 +60,7 @@ class WorkerExtension: ...@@ -59,8 +60,7 @@ class WorkerExtension:
""" """
weights_updated = True weights_updated = True
for name, p in self.model_runner.model.named_parameters(): for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose( weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
p, torch.zeros_like(p))
return weights_updated return weights_updated
...@@ -76,6 +76,7 @@ class ColocateWorkerExtension: ...@@ -76,6 +76,7 @@ class ColocateWorkerExtension:
def report_device_id(self) -> str: def report_device_id(self) -> str:
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.device_uuid = current_platform.get_device_uuid(self.device.index) self.device_uuid = current_platform.get_device_uuid(self.device.index)
return self.device_uuid return self.device_uuid
...@@ -100,6 +101,5 @@ class ColocateWorkerExtension: ...@@ -100,6 +101,5 @@ class ColocateWorkerExtension:
""" """
weights_updated = True weights_updated = True
for name, p in self.model_runner.model.named_parameters(): for name, p in self.model_runner.model.named_parameters():
weights_updated = weights_updated and torch.allclose( weights_updated = weights_updated and torch.allclose(p, torch.zeros_like(p))
p, torch.zeros_like(p))
return weights_updated return weights_updated
...@@ -21,6 +21,7 @@ llm = LLM( ...@@ -21,6 +21,7 @@ llm = LLM(
tensor_parallel_size=8, tensor_parallel_size=8,
) )
""" """
import dataclasses import dataclasses
import os import os
import shutil import shutil
...@@ -33,18 +34,18 @@ from vllm.utils import FlexibleArgumentParser ...@@ -33,18 +34,18 @@ from vllm.utils import FlexibleArgumentParser
def parse_args(): def parse_args():
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser) EngineArgs.add_cli_args(parser)
parser.add_argument("--output", parser.add_argument(
"-o", "--output", "-o", required=True, type=str, help="path to output checkpoint"
required=True, )
type=str, parser.add_argument(
help="path to output checkpoint") "--file-pattern", type=str, help="string pattern of saved filenames"
parser.add_argument("--file-pattern", )
type=str, parser.add_argument(
help="string pattern of saved filenames") "--max-file-size",
parser.add_argument("--max-file-size",
type=str, type=str,
default=5 * 1024**3, default=5 * 1024**3,
help="max size (in bytes) of each safetensors file") help="max size (in bytes) of each safetensors file",
)
return parser.parse_args() return parser.parse_args()
...@@ -68,23 +69,23 @@ def main(args): ...@@ -68,23 +69,23 @@ def main(args):
# For V1 engine, we need to use engine_core.save_sharded_state # For V1 engine, we need to use engine_core.save_sharded_state
print("Using V1 engine save path") print("Using V1 engine save path")
llm.llm_engine.engine_core.save_sharded_state( llm.llm_engine.engine_core.save_sharded_state(
path=args.output, path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
pattern=args.file_pattern, )
max_size=args.max_file_size)
else: else:
# For V0 engine # For V0 engine
print("Using V0 engine save path") print("Using V0 engine save path")
model_executor = llm.llm_engine.model_executor model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output, model_executor.save_sharded_state(
pattern=args.file_pattern, path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
max_size=args.max_file_size) )
# Copy metadata files to output directory # Copy metadata files to output directory
for file in os.listdir(model_path): for file in os.listdir(model_path):
if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"): if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
if os.path.isdir(os.path.join(model_path, file)): if os.path.isdir(os.path.join(model_path, file)):
shutil.copytree(os.path.join(model_path, file), shutil.copytree(
os.path.join(args.output, file)) os.path.join(model_path, file), os.path.join(args.output, file)
)
else: else:
shutil.copy(os.path.join(model_path, file), args.output) shutil.copy(os.path.join(model_path, file), args.output)
......
...@@ -15,20 +15,20 @@ from vllm import LLM, SamplingParams ...@@ -15,20 +15,20 @@ from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
# Guided decoding by Choice (list of possible options) # Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams( guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
choice=["Positive", "Negative"]) sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
sampling_params_choice = SamplingParams(
guided_decoding=guided_decoding_params_choice)
prompt_choice = "Classify this sentiment: vLLM is wonderful!" prompt_choice = "Classify this sentiment: vLLM is wonderful!"
# Guided decoding by Regex # Guided decoding by Regex
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n") guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams( sampling_params_regex = SamplingParams(
guided_decoding=guided_decoding_params_regex, stop=["\n"]) guided_decoding=guided_decoding_params_regex, stop=["\n"]
)
prompt_regex = ( prompt_regex = (
"Generate an email address for Alan Turing, who works in Enigma." "Generate an email address for Alan Turing, who works in Enigma."
"End in .com and new line. Example result:" "End in .com and new line. Example result:"
"alan.turing@enigma.com\n") "alan.turing@enigma.com\n"
)
# Guided decoding by JSON using Pydantic schema # Guided decoding by JSON using Pydantic schema
...@@ -47,10 +47,11 @@ class CarDescription(BaseModel): ...@@ -47,10 +47,11 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema) guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
sampling_params_json = SamplingParams( sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
guided_decoding=guided_decoding_params_json) prompt_json = (
prompt_json = ("Generate a JSON with the brand, model and car_type of" "Generate a JSON with the brand, model and car_type of"
"the most iconic car from the 90's") "the most iconic car from the 90's"
)
# Guided decoding by Grammar # Guided decoding by Grammar
simplified_sql_grammar = """ simplified_sql_grammar = """
...@@ -61,12 +62,11 @@ table ::= "table_1 " | "table_2 " ...@@ -61,12 +62,11 @@ table ::= "table_1 " | "table_2 "
condition ::= column "= " number condition ::= column "= " number
number ::= "1 " | "2 " number ::= "1 " | "2 "
""" """
guided_decoding_params_grammar = GuidedDecodingParams( guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
grammar=simplified_sql_grammar) sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
sampling_params_grammar = SamplingParams( prompt_grammar = (
guided_decoding=guided_decoding_params_grammar) "Generate an SQL query to show the 'username' and 'email'from the 'users' table."
prompt_grammar = ("Generate an SQL query to show the 'username' and 'email'" )
"from the 'users' table.")
def format_output(title: str, output: str): def format_output(title: str, output: str):
...@@ -90,8 +90,7 @@ def main(): ...@@ -90,8 +90,7 @@ def main():
json_output = generate_output(prompt_json, sampling_params_json, llm) json_output = generate_output(prompt_json, sampling_params_json, llm)
format_output("Guided decoding by JSON", json_output) format_output("Guided decoding by JSON", json_output)
grammar_output = generate_output(prompt_grammar, sampling_params_grammar, grammar_output = generate_output(prompt_grammar, sampling_params_grammar, llm)
llm)
format_output("Guided decoding by Grammar", grammar_output) format_output("Guided decoding by Grammar", grammar_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment