Unverified Commit 53aa9194 authored by OlivierDehaene, committed by GitHub

fix(server): fix warpers on CPU (#472)

Closes #471
parent ece7ffa4
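
For context on the bug: StaticWarper.__call__ previously captured a torch.cuda.CUDAGraph unconditionally, which cannot work on a CPU-only host. Below is a minimal, self-contained sketch of the pattern this commit introduces (hypothetical names, not the repository's code): capture the hot path in a CUDA graph when a GPU is available, otherwise run it eagerly.

import torch

def log_softmax_maybe_graphed(scores: torch.Tensor) -> torch.Tensor:
    # Sketch only: the real StaticWarper captures the graph once (sharing a
    # memory pool via `pool=mempool`) and replays it on later calls; this
    # captures inline for brevity. Production CUDA-graph code also usually
    # warms the ops up before capture.
    if torch.cuda.is_available():
        static_in = scores.cuda()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = torch.log_softmax(static_in, -1)
        static_in.copy_(scores)  # feed fresh inputs into the captured buffer
        graph.replay()           # re-run the captured kernels
        return static_out
    # CPU fallback: CUDA graphs require a CUDA device
    return torch.log_softmax(scores, -1)

print(log_softmax_maybe_graphed(torch.randn(1, 8)))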
@@ -237,20 +237,12 @@ def get_model(
         )
     elif model_type == "t5":
-        if sharded:
-            return T5Sharded(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
-        else:
-            return Seq2SeqLM(
-                model_id,
-                revision,
-                quantize=quantize,
-                trust_remote_code=trust_remote_code,
-            )
+        return T5Sharded(
+            model_id,
+            revision,
+            quantize=quantize,
+            trust_remote_code=trust_remote_code,
+        )
 
     if sharded:
         raise ValueError("sharded is not supported for AutoModel")
...
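
In the first hunk, the t5 branch now returns T5Sharded whether or not sharding is requested; dropping the Seq2SeqLM fallback implies T5Sharded also covers the single-process (world-size-1) case. A hedged sketch of the resulting call, with parameter names taken from the hunk above (get_model's full signature is assumed from context, not shown in this diff):

from text_generation_server.models import get_model

# Hypothetical call; any checkpoint whose config reports model_type == "t5"
# now loads through T5Sharded, with or without sharding.
model = get_model(
    "t5-small",
    revision=None,
    sharded=False,
    quantize=None,
    trust_remote_code=False,
)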
@@ -42,25 +42,31 @@ class StaticWarper:
         self.static_next_logprob = None
 
     def __call__(self, scores):
-        if self.cuda_graph is None:
-            self.static_scores = scores
-            self.cuda_graph = torch.cuda.CUDAGraph()
-
-            with torch.cuda.graph(self.cuda_graph, pool=mempool):
-                local_scores = self.static_scores
-                for warper in self.warpers:
-                    local_scores = warper(None, local_scores)
-
-                self.static_warped_scores = local_scores
-                # Compute logprobs
-                self.static_next_logprob = torch.log_softmax(
-                    self.static_warped_scores, -1
-                )
-
-        self.static_scores.copy_(scores)
-        self.cuda_graph.replay()
-
-        return self.static_warped_scores, self.static_next_logprob
+        if torch.cuda.is_available():
+            if self.cuda_graph is None:
+                self.static_scores = scores
+                self.cuda_graph = torch.cuda.CUDAGraph()
+
+                with torch.cuda.graph(self.cuda_graph, pool=mempool):
+                    local_scores = self.static_scores
+                    for warper in self.warpers:
+                        local_scores = warper(None, local_scores)
+
+                    self.static_warped_scores = local_scores
+                    # Compute logprobs
+                    self.static_next_logprob = torch.log_softmax(
+                        self.static_warped_scores, -1
+                    )
+
+            self.static_scores.copy_(scores)
+            self.cuda_graph.replay()
+
+            return self.static_warped_scores, self.static_next_logprob
+
+        # CPU branch
+        for warper in self.warpers:
+            scores = warper(None, scores)
+        return scores, torch.log_softmax(scores, -1)
 
 
 @lru_cache(10)
...
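
The new CPU branch simply runs the warper chain eagerly and computes log-probabilities at the end. A standalone sketch of that path, using Hugging Face warpers that follow the same warper(input_ids, scores) calling convention (the settings and vocab size below are illustrative):

import torch
from transformers import TemperatureLogitsWarper, TopKLogitsWarper

warpers = [TemperatureLogitsWarper(0.7), TopKLogitsWarper(top_k=50)]

scores = torch.randn(1, 32000)  # [batch, vocab_size] logits on CPU
for warper in warpers:
    # input_ids is unused by these warpers, matching warper(None, scores) above
    scores = warper(None, scores)

logprobs = torch.log_softmax(scores, -1)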