Unverified Commit a28325e2 authored by Bowen Bao's avatar Bowen Bao Committed by GitHub
Browse files

Replace python random with torch.rand to enable dynamo.export (#24434)

* Replace python random with torch.rand to enable dynamo.export

* revert changes to flax model code

* Remove unused random import

* Fix torch template

* Move torch.manual_seed(0) to right location
parent c036c814
......@@ -1560,7 +1560,6 @@ class {{cookiecutter.camelcase_modelname}}ForQuestionAnswering({{cookiecutter.ca
{% else %}
import math
import copy
import random
from typing import Optional, Tuple, List, Union
import torch
......@@ -2306,7 +2305,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
dropout_probability = torch.randn([])
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
......@@ -2543,7 +2542,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
dropout_probability = torch.randn([])
if self.training and (dropout_probability < self.layerdrop):
continue
......
......@@ -464,6 +464,7 @@ class GenerationTesterMixin:
**model_kwargs,
)
# beam_search does not automatically interleave `batch_size` dim for `num_beams * num_return_sequences`
torch.manual_seed(0)
kwargs = {}
if model.config.is_encoder_decoder:
encoder_outputs, input_ids, attention_mask = self._get_encoder_outputs(
......@@ -482,7 +483,6 @@ class GenerationTesterMixin:
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
torch.manual_seed(0)
with torch.no_grad():
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
output_beam_sample = model.beam_sample(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment