Unverified Commit 257afc37 authored by Harsha vardhan manoj Bikki's avatar Harsha vardhan manoj Bikki Committed by GitHub
Browse files

[Neuron] Adding support for context-lenght, token-gen buckets. (#7885)


Co-authored-by: default avatarHarsha Bikki <harbikh@amazon.com>
parent 86a677de
import os
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
...@@ -19,8 +26,8 @@ llm = LLM( ...@@ -19,8 +26,8 @@ llm = LLM(
# Currently, this is a known limitation in continuous batching support # Currently, this is a known limitation in continuous batching support
# in transformers-neuronx. # in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx. # TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=128, max_model_len=2048,
block_size=128, block_size=2048,
# The device can be automatically detected when AWS Neuron SDK is installed. # The device can be automatically detected when AWS Neuron SDK is installed.
# The device argument can be either unspecified for automated detection, # The device argument can be either unspecified for automated detection,
# or explicitly assigned. # or explicitly assigned.
......
"""Utilities for selecting and loading neuron models.""" """Utilities for selecting and loading neuron models."""
import importlib import importlib
import os import os
from typing import Dict, Optional, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str: ...@@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
f"{list(_NEURON_SUPPORTED_MODELS.keys())}") f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
env_value = os.getenv(env)
if env_value is None:
return default_value
buckets_remove_empty = filter(
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
buckets_int = map(int, buckets_remove_empty)
buckets_list = list(buckets_int)
return buckets_list
def get_neuron_model(model_config: ModelConfig, def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module: scheduler_config: SchedulerConfig) -> nn.Module:
...@@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig, ...@@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
neuron_config = NeuronConfig( neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config) continuous_batching=continuous_batching_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
# Load the weights from the cached or downloaded files. # Load the weights from the cached or downloaded files.
model.load_weights( model.load_weights(model_config.model,
model_config.model, tp_degree=parallel_config.tensor_parallel_size,
tp_degree=parallel_config.tensor_parallel_size, amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], neuron_config=neuron_config,
neuron_config=neuron_config, context_length_estimate=context_length_estimates,
context_length_estimate=[scheduler_config.max_model_len], n_positions=n_positions,
n_positions=[scheduler_config.max_model_len], batch_size=scheduler_config.max_num_seqs)
batch_size=scheduler_config.max_num_seqs)
return model.eval() return model.eval()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment