Unverified Commit 107f5fc4 authored by Reid's avatar Reid Committed by GitHub
Browse files

[Misc] refactor disaggregated-prefill-v1 example (#18474)


Signed-off-by: default avatarreidliu41 <reid201711@gmail.com>
Co-authored-by: default avatarreidliu41 <reid201711@gmail.com>
parent 907f935d
...@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl ...@@ -5,5 +5,6 @@ This example contains scripts that demonstrate disaggregated prefill in the offl
## Files ## Files
- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially. - `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially.
- Make sure you are in the `examples/offline_inference/disaggregated-prefill-v1` directory before running `run.sh`.
- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`. - `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`.
- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`. - `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`.
...@@ -3,35 +3,47 @@ ...@@ -3,35 +3,47 @@
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
# Read prompts from output.txt
prompts = [] def read_prompts():
try: """Read prompts from output.txt"""
with open("output.txt") as f: prompts = []
for line in f: try:
prompts.append(line.strip()) with open("output.txt") as f:
print(f"Loaded {len(prompts)} prompts from output.txt") for line in f:
except FileNotFoundError: prompts.append(line.strip())
print("Error: output.txt file not found") print(f"Loaded {len(prompts)} prompts from output.txt")
exit(-1) return prompts
except FileNotFoundError:
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) print("Error: output.txt file not found")
exit(-1)
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True,
gpu_memory_utilization=0.8, def main():
max_num_batched_tokens=64, prompts = read_prompts()
max_num_seqs=16, sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10)
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
kv_role="kv_both", enforce_eager=True,
kv_connector_extra_config={ gpu_memory_utilization=0.8,
"shared_storage_path": "local_storage" max_num_batched_tokens=64,
})) #, max_model_len=2048, max_num_batched_tokens=2048) max_num_seqs=16,
kv_transfer_config=KVTransferConfig(
# 1ST generation (prefill instance) kv_connector="SharedStorageConnector",
outputs = llm.generate(prompts, sampling_params) kv_role="kv_both",
kv_connector_extra_config={
for output in outputs: "shared_storage_path": "local_storage"
prompt = output.prompt })) #, max_model_len=2048, max_num_batched_tokens=2048)
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") # 1ST generation (prefill instance)
outputs = llm.generate(prompts, sampling_params)
print("-" * 30)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
if __name__ == "__main__":
main()
...@@ -3,42 +3,54 @@ ...@@ -3,42 +3,54 @@
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
context = "Hi " * 1000
context2 = "Hey " * 500 def read_prompts():
prompts = [ context = "Hi " * 1000
context + "Hello, my name is", context2 = "Hey " * 500
context + "The capital of France is", return [
context2 + "Your name is", context + "Hello, my name is",
context2 + "The capital of China is", context + "The capital of France is",
] context2 + "Your name is",
context2 + "The capital of China is",
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) ]
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
enforce_eager=True, def main():
gpu_memory_utilization=0.8, prompts = read_prompts()
kv_transfer_config=KVTransferConfig(
kv_connector="SharedStorageConnector", sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
kv_role="kv_both",
kv_connector_extra_config={ llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
"shared_storage_path": "local_storage" enforce_eager=True,
})) #, max_model_len=2048, max_num_batched_tokens=2048) gpu_memory_utilization=0.8,
kv_transfer_config=KVTransferConfig(
# 1ST generation (prefill instance) kv_connector="SharedStorageConnector",
outputs = llm.generate( kv_role="kv_both",
prompts, kv_connector_extra_config={
sampling_params, "shared_storage_path": "local_storage"
) })) #, max_model_len=2048, max_num_batched_tokens=2048)
new_prompts = [] # 1ST generation (prefill instance)
for output in outputs: outputs = llm.generate(
prompt = output.prompt prompts,
generated_text = output.outputs[0].text sampling_params,
new_prompts.append(prompt + generated_text) )
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
new_prompts = []
# Write new_prompts to output.txt print("-" * 30)
with open("output.txt", "w") as f: for output in outputs:
for prompt in new_prompts: prompt = output.prompt
f.write(prompt + "\n") generated_text = output.outputs[0].text
print(f"Saved {len(new_prompts)} prompts to output.txt") new_prompts.append(prompt + generated_text)
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 30)
# Write new_prompts to output.txt
with open("output.txt", "w") as f:
for prompt in new_prompts:
f.write(prompt + "\n")
print(f"Saved {len(new_prompts)} prompts to output.txt")
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment