Unverified commit 9d353ca7, authored by Tim Dettmers, committed by GitHub

Merge pull request #87 from lostmsu/main

Add `device` and `dtype` parameters to `StableEmbedding`
parents 7a6563b6 62d39a23
@@ -25,6 +25,8 @@ class StableEmbedding(torch.nn.Embedding):
         scale_grad_by_freq: bool = False,
         sparse: bool = False,
         _weight: Optional[Tensor] = None,
+        device=None,
+        dtype=None,
     ) -> None:
         super().__init__(
             num_embeddings,
@@ -35,8 +37,10 @@ class StableEmbedding(torch.nn.Embedding):
             scale_grad_by_freq,
             sparse,
             _weight,
+            device,
+            dtype,
         )
-        self.norm = torch.nn.LayerNorm(embedding_dim)
+        self.norm = torch.nn.LayerNorm(embedding_dim, device=device)
         GlobalOptimManager.get_instance().register_module_override(
             self, "weight", {"optim_bits": 32}
         )
@@ -68,7 +72,10 @@ class StableEmbedding(torch.nn.Embedding):
             self.sparse,
         )
-        return self.norm(emb)
+        # always apply layer norm in full precision
+        emb = emb.to(torch.get_default_dtype())
+        return self.norm(emb).to(self.weight.dtype)

 class Embedding(torch.nn.Embedding):
...
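For reference, a minimal usage sketch of the new parameters (not part of the diff; it assumes the public `bitsandbytes.nn` module path, arbitrary tensor sizes, and picks half precision only when a GPU is available):

```python
import torch
import bitsandbytes as bnb

# With this change, `device` and `dtype` are forwarded to torch.nn.Embedding,
# so the embedding weights can be allocated directly on the target device
# and, for example, in half precision on GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

emb = bnb.nn.StableEmbedding(
    num_embeddings=1024,
    embedding_dim=128,
    device=device,
    dtype=dtype,
)

tokens = torch.randint(0, 1024, (4, 16), device=device)
out = emb(tokens)

# forward() upcasts the lookup result to the default dtype (float32) for the
# LayerNorm and casts back afterwards, so the output matches the weight dtype.
assert out.dtype == emb.weight.dtype
```

Note that the LayerNorm itself is only given `device`, not `dtype`: per the "always apply layer norm in full precision" comment, normalization runs in the default (full-precision) dtype even when the embedding weights are half precision, and the result is cast back to the weight dtype.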