Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
93d3fd86
Unverified
Commit
93d3fd86
authored
Mar 17, 2022
by
Suraj Patil
Committed by
GitHub
Mar 17, 2022
Browse files
remove jax.ops.index (#16220)
parent
8481ecef
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
10 deletions
+9
-10
examples/research_projects/jax-projects/model_parallel/README.md
...s/research_projects/jax-projects/model_parallel/README.md
+1
-1
src/transformers/generation_flax_logits_process.py
src/transformers/generation_flax_logits_process.py
+5
-5
src/transformers/models/big_bird/modeling_flax_big_bird.py
src/transformers/models/big_bird/modeling_flax_big_bird.py
+1
-1
src/transformers/models/marian/modeling_flax_marian.py
src/transformers/models/marian/modeling_flax_marian.py
+1
-1
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
+1
-2
No files found.
examples/research_projects/jax-projects/model_parallel/README.md
View file @
93d3fd86
...
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
...
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
emb
=
jnp
.
zeros
((
50264
,
model
.
config
.
hidden_size
))
emb
=
jnp
.
zeros
((
50264
,
model
.
config
.
hidden_size
))
# update the first 50257 weights using pre-trained weights
# update the first 50257 weights using pre-trained weights
emb
=
emb
.
at
[
jax
.
ops
.
index
[
:
50257
,
:]
]
.
set
(
model
.
params
[
"transformer"
][
"wte"
][
"embedding"
])
emb
=
emb
.
at
[:
50257
,
:].
set
(
model
.
params
[
"transformer"
][
"wte"
][
"embedding"
])
params
=
model
.
params
params
=
model
.
params
params
[
"transformer"
][
"wte"
][
"embedding"
]
=
emb
params
[
"transformer"
][
"wte"
][
"embedding"
]
=
emb
...
...
src/transformers/generation_flax_logits_process.py
View file @
93d3fd86
...
@@ -143,10 +143,10 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
...
@@ -143,10 +143,10 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
# include the token that is higher than top_p as well
# include the token that is higher than top_p as well
score_mask
=
jnp
.
roll
(
score_mask
,
1
)
score_mask
=
jnp
.
roll
(
score_mask
,
1
)
score_mask
|=
score_mask
.
at
[
jax
.
ops
.
index
[
:,
0
]
]
.
set
(
True
)
score_mask
|=
score_mask
.
at
[:,
0
].
set
(
True
)
# min tokens to keep
# min tokens to keep
score_mask
=
score_mask
.
at
[
jax
.
ops
.
index
[
:,
:
self
.
min_tokens_to_keep
]
]
.
set
(
True
)
score_mask
=
score_mask
.
at
[:,
:
self
.
min_tokens_to_keep
].
set
(
True
)
topk_next_scores
=
jnp
.
where
(
score_mask
,
topk_scores
,
mask_scores
)
topk_next_scores
=
jnp
.
where
(
score_mask
,
topk_scores
,
mask_scores
)
next_scores
=
jax
.
lax
.
sort_key_val
(
topk_indices
,
topk_next_scores
)[
-
1
]
next_scores
=
jax
.
lax
.
sort_key_val
(
topk_indices
,
topk_next_scores
)[
-
1
]
...
@@ -207,7 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
...
@@ -207,7 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
apply_penalty
=
1
-
jnp
.
bool_
(
cur_len
-
1
)
apply_penalty
=
1
-
jnp
.
bool_
(
cur_len
-
1
)
scores
=
jnp
.
where
(
apply_penalty
,
new_scores
.
at
[
jax
.
ops
.
index
[
:,
self
.
bos_token_id
]
]
.
set
(
0
),
scores
)
scores
=
jnp
.
where
(
apply_penalty
,
new_scores
.
at
[:,
self
.
bos_token_id
].
set
(
0
),
scores
)
return
scores
return
scores
...
@@ -232,7 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
...
@@ -232,7 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
apply_penalty
=
1
-
jnp
.
bool_
(
cur_len
-
self
.
max_length
+
1
)
apply_penalty
=
1
-
jnp
.
bool_
(
cur_len
-
self
.
max_length
+
1
)
scores
=
jnp
.
where
(
apply_penalty
,
new_scores
.
at
[
jax
.
ops
.
index
[
:,
self
.
eos_token_id
]
]
.
set
(
0
),
scores
)
scores
=
jnp
.
where
(
apply_penalty
,
new_scores
.
at
[:,
self
.
eos_token_id
].
set
(
0
),
scores
)
return
scores
return
scores
...
@@ -263,6 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
...
@@ -263,6 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
# create boolean flag to decide if min length penalty should be applied
# create boolean flag to decide if min length penalty should be applied
apply_penalty
=
1
-
jnp
.
clip
(
cur_len
-
self
.
min_length
,
0
,
1
)
apply_penalty
=
1
-
jnp
.
clip
(
cur_len
-
self
.
min_length
,
0
,
1
)
scores
=
jnp
.
where
(
apply_penalty
,
scores
.
at
[
jax
.
ops
.
index
[
:,
self
.
eos_token_id
]
]
.
set
(
-
float
(
"inf"
)),
scores
)
scores
=
jnp
.
where
(
apply_penalty
,
scores
.
at
[:,
self
.
eos_token_id
].
set
(
-
float
(
"inf"
)),
scores
)
return
scores
return
scores
src/transformers/models/big_bird/modeling_flax_big_bird.py
View file @
93d3fd86
...
@@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
...
@@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
if
token_type_ids
is
None
:
if
token_type_ids
is
None
:
token_type_ids
=
(
~
logits_mask
).
astype
(
"i4"
)
token_type_ids
=
(
~
logits_mask
).
astype
(
"i4"
)
logits_mask
=
jnp
.
expand_dims
(
logits_mask
,
axis
=
2
)
logits_mask
=
jnp
.
expand_dims
(
logits_mask
,
axis
=
2
)
logits_mask
=
logits_mask
.
at
[
jax
.
ops
.
index
[
:,
0
]
]
.
set
(
False
)
logits_mask
=
logits_mask
.
at
[:,
0
].
set
(
False
)
# init input tensors if not passed
# init input tensors if not passed
if
token_type_ids
is
None
:
if
token_type_ids
is
None
:
...
...
src/transformers/models/marian/modeling_flax_marian.py
View file @
93d3fd86
...
@@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
...
@@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
def
_adapt_logits_for_beam_search
(
self
,
logits
):
def
_adapt_logits_for_beam_search
(
self
,
logits
):
"""This function enforces the padding token never to be generated."""
"""This function enforces the padding token never to be generated."""
logits
=
logits
.
at
[
jax
.
ops
.
index
[
:,
:,
self
.
config
.
pad_token_id
]
]
.
set
(
float
(
"-inf"
))
logits
=
logits
.
at
[:,
:,
self
.
config
.
pad_token_id
].
set
(
float
(
"-inf"
))
return
logits
return
logits
def
prepare_inputs_for_generation
(
def
prepare_inputs_for_generation
(
...
...
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
View file @
93d3fd86
...
@@ -964,8 +964,7 @@ class FlaxWav2Vec2Module(nn.Module):
...
@@ -964,8 +964,7 @@ class FlaxWav2Vec2Module(nn.Module):
# these two operations makes sure that all values
# these two operations makes sure that all values
# before the output lengths indices are attended to
# before the output lengths indices are attended to
idx
=
jax
.
ops
.
index
[
jnp
.
arange
(
attention_mask
.
shape
[
0
]),
output_lengths
-
1
]
attention_mask
=
attention_mask
.
at
[
jnp
.
arange
(
attention_mask
.
shape
[
0
]),
output_lengths
-
1
].
set
(
1
)
attention_mask
=
attention_mask
.
at
[
idx
].
set
(
1
)
attention_mask
=
jnp
.
flip
(
jnp
.
flip
(
attention_mask
,
-
1
).
cumsum
(
-
1
),
-
1
).
astype
(
"bool"
)
attention_mask
=
jnp
.
flip
(
jnp
.
flip
(
attention_mask
,
-
1
).
cumsum
(
-
1
),
-
1
).
astype
(
"bool"
)
hidden_states
,
extract_features
=
self
.
feature_projection
(
extract_features
,
deterministic
=
deterministic
)
hidden_states
,
extract_features
=
self
.
feature_projection
(
extract_features
,
deterministic
=
deterministic
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment