"vscode:/vscode.git/clone" did not exist on "e62091d5a7e953534a74d47888dd3c6eda96f3d4"
Unverified commit c186e816, authored by Lianmin Zheng and committed by GitHub
Browse files

[FLAX] Add dtype to embedding for gpt2 model (#18462)

* [FLAX] Add dtype to embedding for gpt2 model

* lint
parent baa00f65
@@ -597,11 +597,13 @@ class FlaxGPT2Module(nn.Module):
  self.config.vocab_size,
  self.embed_dim,
  embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
  )
  self.wpe = nn.Embed(
  self.config.max_position_embeddings,
  self.embed_dim,
  embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
+ dtype=self.dtype,
  )
  self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
  self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment