Unverified Commit db49ad43 authored by shadeMe's avatar shadeMe
Browse files

Add `device` parameter to `Embedding`

parent 9cac5dd1
......@@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq: bool = False,
sparse: bool = False,
_weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None:
super().__init__(
num_embeddings,
......@@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
device=device
)
GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment