chenpangpang / transformers / Commits

Commit 93d3fd86 (unverified)
remove jax.ops.index (#16220)

Authored Mar 17, 2022 by Suraj Patil; committed by GitHub on Mar 17, 2022
Parent: 8481ecef
Showing 5 changed files with 9 additions and 10 deletions.
examples/research_projects/jax-projects/model_parallel/README.md  +1 -1
src/transformers/generation_flax_logits_process.py  +5 -5
src/transformers/models/big_bird/modeling_flax_big_bird.py  +1 -1
src/transformers/models/marian/modeling_flax_marian.py  +1 -1
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py  +1 -2
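Note on the API being removed: jax.ops.index (together with jax.ops.index_update and the other jax.ops helpers) was deprecated upstream and is gone from recent JAX releases; the supported replacement is the array's .at[...] property, which performs the same out-of-place indexed update. A minimal sketch of the equivalence, illustrative only since the jax.ops form no longer runs on current JAX:

import jax.numpy as jnp

x = jnp.zeros((2, 3))

# removed API:
#   x = jax.ops.index_update(x, jax.ops.index[:, 0], 1.0)

# current API: .at[...].set(...) returns a new array; x itself is unchanged
y = x.at[:, 0].set(1.0)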
examples/research_projects/jax-projects/model_parallel/README.md
@@ -35,7 +35,7 @@ model = FlaxGPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
 emb = jnp.zeros((50264, model.config.hidden_size))
 # update the first 50257 weights using pre-trained weights
-emb = emb.at[jax.ops.index[:50257, :]].set(model.params["transformer"]["wte"]["embedding"])
+emb = emb.at[:50257, :].set(model.params["transformer"]["wte"]["embedding"])
 params = model.params
 params["transformer"]["wte"]["embedding"] = emb
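The README snippet above grows GPT-Neo's token embedding from 50257 to 50264 rows and copies the pre-trained rows in with a functional .at update. A small sanity check one could run right after the copy, before params is reassigned (hypothetical addition, not part of the README):

assert emb.shape == (50264, model.config.hidden_size)
# the first 50257 rows carry the pre-trained weights, the new rows stay zero
assert (emb[:50257] == model.params["transformer"]["wte"]["embedding"]).all()
assert (emb[50257:] == 0).all()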
src/transformers/generation_flax_logits_process.py
@@ -143,10 +143,10 @@ class FlaxTopPLogitsWarper(FlaxLogitsWarper):
         # include the token that is higher than top_p as well
         score_mask = jnp.roll(score_mask, 1)
-        score_mask |= score_mask.at[jax.ops.index[:, 0]].set(True)
+        score_mask |= score_mask.at[:, 0].set(True)

         # min tokens to keep
-        score_mask = score_mask.at[jax.ops.index[:, : self.min_tokens_to_keep]].set(True)
+        score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)

         topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
         next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]

@@ -207,7 +207,7 @@ class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
         apply_penalty = 1 - jnp.bool_(cur_len - 1)

-        scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.bos_token_id]].set(0), scores)
+        scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)

         return scores

@@ -232,7 +232,7 @@ class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
         apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)

-        scores = jnp.where(apply_penalty, new_scores.at[jax.ops.index[:, self.eos_token_id]].set(0), scores)
+        scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)

         return scores

@@ -263,6 +263,6 @@ class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
         # create boolean flag to decide if min length penalty should be applied
         apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

-        scores = jnp.where(apply_penalty, scores.at[jax.ops.index[:, self.eos_token_id]].set(-float("inf")), scores)
+        scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)

         return scores
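The FlaxTopPLogitsWarper hunk is representative of this whole file: the boolean keep-mask for nucleus sampling is now built with plain .at updates, and the forced-BOS/EOS and min-length hunks above do the same for a single vocabulary column. A toy sketch of the top-p masking logic on one row of already-sorted scores (a simplified illustration, not the library code):

import jax
import jax.numpy as jnp

top_p, min_tokens_to_keep = 0.9, 1
sorted_scores = jnp.array([[2.0, 1.0, 0.5, -1.0]])    # one row, sorted descending
cumulative_probs = jax.nn.softmax(sorted_scores, axis=-1).cumsum(axis=-1)

score_mask = cumulative_probs < top_p                 # tokens strictly inside the nucleus
score_mask = jnp.roll(score_mask, 1)                  # also admit the token that crosses top_p
score_mask |= score_mask.at[:, 0].set(True)           # the highest-probability token is always kept
score_mask = score_mask.at[:, :min_tokens_to_keep].set(True)
# score_mask is now [[ True  True  True False]]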
src/transformers/models/big_bird/modeling_flax_big_bird.py
@@ -2124,7 +2124,7 @@ class FlaxBigBirdForQuestionAnswering(FlaxBigBirdPreTrainedModel):
             if token_type_ids is None:
                 token_type_ids = (~logits_mask).astype("i4")
             logits_mask = jnp.expand_dims(logits_mask, axis=2)
-            logits_mask = logits_mask.at[jax.ops.index[:, 0]].set(False)
+            logits_mask = logits_mask.at[:, 0].set(False)

         # init input tensors if not passed
         if token_type_ids is None:
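For context on the BigBird hunk (my reading of the surrounding question-answering code, not something stated in the diff): logits_mask flags the question tokens so they can be excluded from the answer-span logits, token_type_ids is derived from its complement, and position 0 is switched back off so it stays available. A toy illustration:

import jax.numpy as jnp

# toy example: the first 3 of 6 positions belong to the question
logits_mask = jnp.array([[True, True, True, False, False, False]])

token_type_ids = (~logits_mask).astype("i4")        # 0 for question tokens, 1 for the rest
logits_mask = jnp.expand_dims(logits_mask, axis=2)  # (batch, seq_len, 1), broadcastable over the logits
logits_mask = logits_mask.at[:, 0].set(False)       # un-mask position 0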
src/transformers/models/marian/modeling_flax_marian.py
@@ -1422,7 +1422,7 @@ class FlaxMarianMTModel(FlaxMarianPreTrainedModel):
     def _adapt_logits_for_beam_search(self, logits):
         """This function enforces the padding token never to be generated."""
-        logits = logits.at[jax.ops.index[:, :, self.config.pad_token_id]].set(float("-inf"))
+        logits = logits.at[:, :, self.config.pad_token_id].set(float("-inf"))
         return logits

     def prepare_inputs_for_generation(
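The Marian change sits in _adapt_logits_for_beam_search, which blanks out the pad token across every beam so beam search can never select it. A minimal illustration with hypothetical shapes (batch of 1, 2 beams, vocabulary of 4, pad id 2):

import jax.numpy as jnp

pad_token_id = 2
logits = jnp.zeros((1, 2, 4))                              # (batch, num_beams, vocab_size)
logits = logits.at[:, :, pad_token_id].set(float("-inf"))  # pad can never win the argmax

assert (logits.argmax(axis=-1) != pad_token_id).all()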
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py
@@ -964,8 +964,7 @@ class FlaxWav2Vec2Module(nn.Module):
             # these two operations makes sure that all values
             # before the output lengths indices are attended to
-            idx = jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1]
-            attention_mask = attention_mask.at[idx].set(1)
+            attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
             attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

         hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
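The wav2vec2 hunk folds the removed jax.ops.index helper straight into the .at[] call: a 1 is written at the last valid feature frame of each sequence, and a reversed cumulative sum then expands that single marker into a full prefix mask. A worked toy example (illustrative; two sequences with feature lengths 2 and 4 in a padded batch of width 5):

import jax.numpy as jnp

output_lengths = jnp.array([2, 4])
attention_mask = jnp.zeros((2, 5), dtype=jnp.int32)

# mark the last attended position of each sequence
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
# -> [[0 1 0 0 0],
#     [0 0 0 1 0]]

# reversed cumsum turns each marker into "this position and everything before it"
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
# -> [[ True  True False False False],
#     [ True  True  True  True False]]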