Unverified Commit db49ad43 authored by shadeMe's avatar shadeMe
Browse files

Add `device` parameter to `Embedding`

parent 9cac5dd1
...@@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding): ...@@ -92,6 +92,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq: bool = False, scale_grad_by_freq: bool = False,
sparse: bool = False, sparse: bool = False,
_weight: Optional[Tensor] = None, _weight: Optional[Tensor] = None,
device: Optional[device] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
num_embeddings, num_embeddings,
...@@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding): ...@@ -102,6 +103,7 @@ class Embedding(torch.nn.Embedding):
scale_grad_by_freq, scale_grad_by_freq,
sparse, sparse,
_weight, _weight,
device=device
) )
GlobalOptimManager.get_instance().register_module_override( GlobalOptimManager.get_instance().register_module_override(
self, "weight", {"optim_bits": 32} self, "weight", {"optim_bits": 32}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment