Unverified Commit 56653520 authored by Charlene Yang's avatar Charlene Yang Committed by GitHub
Browse files

[PyTorch] Make breaking change in `InferenceParams.init` more explicit (#1619)

parent 69365f88
...@@ -128,9 +128,9 @@ class InferenceParams: ...@@ -128,9 +128,9 @@ class InferenceParams:
self, self,
max_batch_size: int, max_batch_size: int,
max_sequence_length: int, max_sequence_length: int,
num_heads_kv: int = 16, num_heads_kv: int = None,
head_dim_k: int = 64, head_dim_k: int = None,
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = None,
head_dim_v: int = None, head_dim_v: int = None,
is_paged: bool = False, is_paged: bool = False,
total_num_pages: int = None, total_num_pages: int = None,
...@@ -141,6 +141,10 @@ class InferenceParams: ...@@ -141,6 +141,10 @@ class InferenceParams:
): ):
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_sequence_length = max_sequence_length self.max_sequence_length = max_sequence_length
assert all(x is not None for x in [num_heads_kv, head_dim_k, dtype]), (
"num_heads_kv, head_dim_k, and dtype are required for InferenceParams since Transformer"
" Engine 2.2."
)
self.num_heads_kv = num_heads_kv self.num_heads_kv = num_heads_kv
self.head_dim_k = head_dim_k self.head_dim_k = head_dim_k
self.dtype = dtype self.dtype = dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment