Unverified Commit ebed81fb authored by aws-elaineyz's avatar aws-elaineyz Committed by GitHub
Browse files

Update default neuron config for speculation (#18274)


Signed-off-by: default avatarElaine Zhao <elaineyz@amazon.com>
Co-authored-by: default avatarShashwat Srijan <sssrijan@amazon.com>
Co-authored-by: default avatarAakash Shetty <sheaak@amazon.com>
parent e2d7d312
......@@ -502,7 +502,7 @@ def _get_default_neuron_config(model_config: ModelConfig,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
enable_bucketing=True,
is_continuous_batching=(batch_size > 1),
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
padding_side="right",
......@@ -520,6 +520,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
args."""
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=scheduler_config.max_num_seqs,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
......@@ -527,6 +528,7 @@ def _get_default_speculation_config(model_config: ModelConfig,
trace_tokengen_model=False,
enable_fused_speculation=True,
enable_bucketing=True,
is_continuous_batching=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
on_device_sampling_config=dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment