Unverified commit a1f2dc90 authored by Xuchun Shang, committed by GitHub

[Bug fix] [PP] fix wrong dtype for quantized model (#12247)


Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
parent ea961060
@@ -323,11 +323,11 @@ class CudaGraphRunner:
         self.pp_proxy_tensors = {
             "hidden_states": torch.zeros(
                 (self.max_bs, self.model_runner.model_config.hidden_size),
-                dtype=torch.bfloat16,
+                dtype=self.model_runner.model_config.dtype,
             ),
             "residual": torch.zeros(
                 (self.max_bs, self.model_runner.model_config.hidden_size),
-                dtype=torch.bfloat16,
+                dtype=self.model_runner.model_config.dtype,
             ),
         }
......
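The hardcoded torch.bfloat16 breaks when a quantized checkpoint is served with a different activation dtype (for example float16): the preallocated pipeline-parallel proxy tensors then no longer match the hidden states handed over by the previous stage. Below is a minimal sketch of the mismatch; _FakeModelConfig and max_bs are hypothetical stand-ins for the real self.model_runner.model_config fields, not code from the repository.

import torch

class _FakeModelConfig:
    # Stand-in for self.model_runner.model_config in the patched code.
    hidden_size = 4096
    dtype = torch.float16  # e.g. a quantized model served with fp16 activations

cfg = _FakeModelConfig()
max_bs = 8

# Before the fix: proxy buffers hardcoded to bfloat16.
old_proxy = torch.zeros((max_bs, cfg.hidden_size), dtype=torch.bfloat16)
# After the fix: proxy buffers follow the model config dtype.
new_proxy = torch.zeros((max_bs, cfg.hidden_size), dtype=cfg.dtype)

hidden_states = torch.randn(max_bs, cfg.hidden_size, dtype=cfg.dtype)
print(old_proxy.dtype == hidden_states.dtype)  # False -> dtype mismatch with the incoming activations
print(new_proxy.dtype == hidden_states.dtype)  # True  -> buffers match the model's dtype

Reading the dtype from the model config keeps the proxy buffers consistent with whatever dtype the model actually runs in, so CUDA graph capture and replay operate on matching tensor types.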