"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e5c12c03b711aa2a31b562de3bce92431b4bf662"
Unverified commit c186e816, authored by Lianmin Zheng and committed by GitHub
Browse files

[FLAX] Add dtype to embedding for gpt2 model (#18462)

* [FLAX] Add dtype to embedding for gpt2 model

* lint
parent baa00f65
...@@ -597,11 +597,13 @@ class FlaxGPT2Module(nn.Module): ...@@ -597,11 +597,13 @@ class FlaxGPT2Module(nn.Module):
self.config.vocab_size, self.config.vocab_size,
self.embed_dim, self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.wpe = nn.Embed( self.wpe = nn.Embed(
self.config.max_position_embeddings, self.config.max_position_embeddings,
self.embed_dim, self.embed_dim,
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
dtype=self.dtype,
) )
self.dropout = nn.Dropout(rate=self.config.embd_pdrop) self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype) self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment