Unverified Commit c0fa2aa0 authored by James Reed's avatar James Reed Committed by GitHub
Browse files

[HFTracer] Make embeddings ops take on the dtype of the weight (#22347)

* [HFTracer] Make embeddings ops take on the dtype of the weight

* fix bug
parent e8cc0255
......@@ -171,13 +171,13 @@ _SUPPORTED_MODELS = tuple(sorted(set(_REGULAR_SUPPORTED_MODELS + _SPECIAL_SUPPOR
def torch_nn_embedding(self, input):
return torch.empty(*input.shape, self.weight.shape[-1], device="meta")
return torch.empty(*input.shape, self.weight.shape[-1], device="meta", dtype=self.weight.dtype)
def torch_nn_functional_embedding(
input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False
):
return torch.empty(*input.shape, weight.shape[-1], device="meta")
return torch.empty(*input.shape, weight.shape[-1], device="meta", dtype=weight.dtype)
def torch_nn_layernorm(self, input):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment