"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "49428eb7efbe9e5f7d9b9f30bb5fb02d7c2f8fb4"
Commit fcab93b4 authored by Tri Dao
Browse files

[Gen] Minor tweak to allocate_inference_cache

parent ba2fe7f3
...@@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -158,6 +158,9 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
class GenerationMixin: class GenerationMixin:
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
    """Allocate the per-layer key/value cache used during incremental decoding.

    Interface stub on ``GenerationMixin``: concrete models must override this.
    Called by ``update_graph_cache`` (see the second hunk in this diff) when the
    model provides its own allocator.

    Args:
        batch_size: Maximum batch size the cache must accommodate.
        max_seqlen: Maximum sequence length to reserve cache space for.
        dtype: Optional dtype for the cache tensors; ``None`` presumably means
            "use the model's dtype" — TODO confirm against an implementation.
        **kwargs: Extra, implementation-specific options.

    Raises:
        NotImplementedError: always, in this base-class stub.
    """
    raise NotImplementedError
def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0, def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0,
return_dict_in_generate=False, output_scores=False, **kwargs): return_dict_in_generate=False, output_scores=False, **kwargs):
output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p, output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p,
...@@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p ...@@ -224,11 +227,11 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
gc.collect() gc.collect()
cache.device, cache.dtype = device, dtype cache.device, cache.dtype = device, dtype
cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
if hasattr(model, 'allocate_inference_cache'): if hasattr(model, 'allocate_inference_cache'):
inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
else: else:
headdim = getattr(model.config, 'head_dim',
model.config.hidden_size // model.config.num_attention_heads)
inf_cache = allocate_inference_cache( inf_cache = allocate_inference_cache(
batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim, batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim,
model.config.num_hidden_layers, device, dtype model.config.num_hidden_layers, device, dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment