Unverified commit c186e816, authored by Lianmin Zheng and committed by GitHub
Browse files

[FLAX] Add dtype to embedding for gpt2 model (#18462)

* [FLAX] Add dtype to embedding for gpt2 model

* lint
parent baa00f65
......@@ -597,11 +597,13 @@ class FlaxGPT2Module(nn.Module):
self.config.vocab_size,
self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.wpe = nn.Embed(
self.config.max_position_embeddings,
self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
)
self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment