"tests/rembert/test_modeling_rembert.py" did not exist on "3425936643b157bda181af169b371dcf0a3ad3eb"
Unverified commit 5e8c8eb5, authored by Aaron Gokaslan, committed by GitHub
Browse files

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
...@@ -5,7 +5,7 @@ target-version = ['py37'] ...@@ -5,7 +5,7 @@ target-version = ['py37']
[tool.ruff] [tool.ruff]
# Never enforce `E501` (line length violations). # Never enforce `E501` (line length violations).
ignore = ["E501", "E741", "W605"] ignore = ["E501", "E741", "W605"]
select = ["E", "F", "I", "W"] select = ["C", "E", "F", "I", "W"]
line-length = 119 line-length = 119
# Ignore import violations in all `__init__.py` files. # Ignore import violations in all `__init__.py` files.
......
...@@ -557,9 +557,9 @@ def stop_memory_tracing( ...@@ -557,9 +557,9 @@ def stop_memory_tracing(
cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc cumulative_memory_dict[frame][2] += cpu_gpu_mem_inc
cumulative_memory = sorted( cumulative_memory = sorted(
list(cumulative_memory_dict.items()), key=lambda x: x[1][2], reverse=True cumulative_memory_dict.items(), key=lambda x: x[1][2], reverse=True
) # order by the total CPU + GPU memory increase ) # order by the total CPU + GPU memory increase
cumulative_memory = list( cumulative_memory = [
MemoryState( MemoryState(
frame=frame, frame=frame,
cpu=Memory(cpu_mem_inc), cpu=Memory(cpu_mem_inc),
...@@ -567,7 +567,7 @@ def stop_memory_tracing( ...@@ -567,7 +567,7 @@ def stop_memory_tracing(
cpu_gpu=Memory(cpu_gpu_mem_inc), cpu_gpu=Memory(cpu_gpu_mem_inc),
) )
for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory for frame, (cpu_mem_inc, gpu_mem_inc, cpu_gpu_mem_inc) in cumulative_memory
) ]
memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True) memory_curr_trace = sorted(memory_curr_trace, key=lambda x: x.cpu_gpu.bytes, reverse=True)
......
...@@ -324,7 +324,7 @@ class PretrainedConfig(PushToHubMixin): ...@@ -324,7 +324,7 @@ class PretrainedConfig(PushToHubMixin):
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
) )
self.id2label = dict((int(key), value) for key, value in self.id2label.items()) self.id2label = {int(key): value for key, value in self.id2label.items()}
# Keys are always strings in JSON so convert ids to int here. # Keys are always strings in JSON so convert ids to int here.
else: else:
self.num_labels = kwargs.pop("num_labels", 2) self.num_labels = kwargs.pop("num_labels", 2)
...@@ -696,7 +696,7 @@ class PretrainedConfig(PushToHubMixin): ...@@ -696,7 +696,7 @@ class PretrainedConfig(PushToHubMixin):
config = cls(**config_dict) config = cls(**config_dict)
if hasattr(config, "pruned_heads"): if hasattr(config, "pruned_heads"):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) config.pruned_heads = {int(key): value for key, value in config.pruned_heads.items()}
# Update config with kwargs if needed # Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs: if "num_labels" in kwargs and "id2label" in kwargs:
......
...@@ -367,13 +367,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf ...@@ -367,13 +367,13 @@ def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None, inf
# keep for quick debug: # keep for quick debug:
# from pprint import pprint; pprint(config) # from pprint import pprint; pprint(config)
kwargs = dict( kwargs = {
model=model, "model": model,
model_parameters=model_parameters, "model_parameters": model_parameters,
config_params=config, "config_params": config,
optimizer=optimizer, "optimizer": optimizer,
lr_scheduler=lr_scheduler, "lr_scheduler": lr_scheduler,
) }
deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) deepspeed_engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs)
......
...@@ -188,7 +188,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin): ...@@ -188,7 +188,7 @@ class SequenceFeatureExtractor(FeatureExtractionMixin):
truncated_inputs = [] truncated_inputs = []
for i in range(batch_size): for i in range(batch_size):
inputs = dict((k, v[i]) for k, v in processed_features.items()) inputs = {k: v[i] for k, v in processed_features.items()}
# truncation # truncation
inputs_slice = self._truncate( inputs_slice = self._truncate(
inputs, inputs,
......
...@@ -208,12 +208,12 @@ class DisjunctiveTrie: ...@@ -208,12 +208,12 @@ class DisjunctiveTrie:
""" """
self.max_height = max([len(one) for one in nested_token_ids]) self.max_height = max([len(one) for one in nested_token_ids])
root = dict() root = {}
for token_ids in nested_token_ids: for token_ids in nested_token_ids:
level = root level = root
for tidx, token_id in enumerate(token_ids): for tidx, token_id in enumerate(token_ids):
if token_id not in level: if token_id not in level:
level[token_id] = dict() level[token_id] = {}
level = level[token_id] level = level[token_id]
......
...@@ -951,7 +951,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): ...@@ -951,7 +951,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]): for k in range(input_ids.shape[0]):
seq = [t for t in input_ids[k, self.begin_index :].tolist()] seq = list(input_ids[k, self.begin_index :].tolist())
last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
......
...@@ -115,7 +115,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]: ...@@ -115,7 +115,7 @@ def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
if is_valid_image(images): if is_valid_image(images):
if images.ndim == expected_ndims + 1: if images.ndim == expected_ndims + 1:
# Batch of images # Batch of images
images = [image for image in images] images = list(images)
elif images.ndim == expected_ndims: elif images.ndim == expected_ndims:
# Single image # Single image
images = [images] images = [images]
......
...@@ -365,7 +365,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be ...@@ -365,7 +365,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
name="huggingface-tune", name="huggingface-tune",
type="offline", type="offline",
parameters=trainer.hp_space(None), parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")], metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
parallel_bandwidth=1, parallel_bandwidth=1,
budget=n_trials, budget=n_trials,
) )
...@@ -402,7 +402,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be ...@@ -402,7 +402,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
experiment = conn.experiments().create( experiment = conn.experiments().create(
name="huggingface-tune", name="huggingface-tune",
parameters=trainer.hp_space(None), parameters=trainer.hp_space(None),
metrics=[dict(name="objective", objective=direction, strategy="optimize")], metrics=[{"name": "objective", "objective": direction, "strategy": "optimize"}],
parallel_bandwidth=1, parallel_bandwidth=1,
observation_budget=n_trials, observation_budget=n_trials,
project="huggingface", project="huggingface",
...@@ -425,7 +425,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be ...@@ -425,7 +425,7 @@ def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> Be
metrics = trainer.evaluate() metrics = trainer.evaluate()
trainer.objective = trainer.compute_objective(metrics) trainer.objective = trainer.compute_objective(metrics)
values = [dict(name="objective", value=trainer.objective)] values = [{"name": "objective", "value": trainer.objective}]
obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values) obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]") logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
experiment = conn.experiments(experiment.id).fetch() experiment = conn.experiments(experiment.id).fetch()
......
...@@ -162,7 +162,7 @@ class KerasMetricCallback(Callback): ...@@ -162,7 +162,7 @@ class KerasMetricCallback(Callback):
def _postprocess_predictions_or_labels(self, inputs): def _postprocess_predictions_or_labels(self, inputs):
if isinstance(inputs[0], dict): if isinstance(inputs[0], dict):
outputs = dict() outputs = {}
for key in inputs[0].keys(): for key in inputs[0].keys():
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs]) outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
# If it's a dict with only one key, just return the array # If it's a dict with only one key, just return the array
......
...@@ -677,7 +677,7 @@ class TrainingSummary: ...@@ -677,7 +677,7 @@ class TrainingSummary:
_, eval_lines, eval_results = parse_keras_history(keras_history) _, eval_lines, eval_results = parse_keras_history(keras_history)
else: else:
eval_lines = [] eval_lines = []
eval_results = dict() eval_results = {}
hyperparameters = extract_hyperparameters_from_keras(model) hyperparameters = extract_hyperparameters_from_keras(model)
return cls( return cls(
...@@ -706,7 +706,7 @@ def parse_keras_history(logs): ...@@ -706,7 +706,7 @@ def parse_keras_history(logs):
# This looks like a `History` object # This looks like a `History` object
if not hasattr(logs, "epoch"): if not hasattr(logs, "epoch"):
# This history looks empty, return empty results # This history looks empty, return empty results
return None, [], dict() return None, [], {}
logs.history["epoch"] = logs.epoch logs.history["epoch"] = logs.epoch
logs = logs.history logs = logs.history
else: else:
...@@ -716,7 +716,7 @@ def parse_keras_history(logs): ...@@ -716,7 +716,7 @@ def parse_keras_history(logs):
lines = [] lines = []
for i in range(len(logs["epoch"])): for i in range(len(logs["epoch"])):
epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()} epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
values = dict() values = {}
for k, v in epoch_dict.items(): for k, v in epoch_dict.items():
if k.startswith("val_"): if k.startswith("val_"):
k = "validation_" + k[4:] k = "validation_" + k[4:]
...@@ -797,7 +797,7 @@ def parse_log_history(log_history): ...@@ -797,7 +797,7 @@ def parse_log_history(log_history):
def extract_hyperparameters_from_keras(model): def extract_hyperparameters_from_keras(model):
import tensorflow as tf import tensorflow as tf
hyperparameters = dict() hyperparameters = {}
if hasattr(model, "optimizer") and model.optimizer is not None: if hasattr(model, "optimizer") and model.optimizer is not None:
hyperparameters["optimizer"] = model.optimizer.get_config() hyperparameters["optimizer"] = model.optimizer.get_config()
else: else:
......
...@@ -76,7 +76,7 @@ def rename_key_and_reshape_tensor( ...@@ -76,7 +76,7 @@ def rename_key_and_reshape_tensor(
def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool: def is_key_or_prefix_key_in_dict(key: Tuple[str]) -> bool:
"""Checks if `key` of `(prefix,) + key` is in random_flax_state_dict""" """Checks if `key` of `(prefix,) + key` is in random_flax_state_dict"""
return len(set(random_flax_state_dict) & set([key, (model_prefix,) + key])) > 0 return len(set(random_flax_state_dict) & {key, (model_prefix,) + key}) > 0
# layer norm # layer norm
renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",) renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
...@@ -122,10 +122,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): ...@@ -122,10 +122,10 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
flax_state_dict = {} flax_state_dict = {}
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
) )
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
) )
# Need to change some parameters name to match Flax names # Need to change some parameters name to match Flax names
...@@ -179,10 +179,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): ...@@ -179,10 +179,10 @@ def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
random_flax_state_dict = flatten_dict(flax_model.params) random_flax_state_dict = flatten_dict(flax_model.params)
load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and (
model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) model_prefix in {k.split(".")[0] for k in pt_state_dict.keys()}
) )
load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( load_base_model_into_model_with_head = (model_prefix in flax_model.params) and (
model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) model_prefix not in {k.split(".")[0] for k in pt_state_dict.keys()}
) )
# Need to change some parameters name to match Flax names # Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items(): for pt_key, pt_tensor in pt_state_dict.items():
...@@ -267,10 +267,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state): ...@@ -267,10 +267,10 @@ def load_flax_weights_in_pytorch_model(pt_model, flax_state):
pt_model_dict = pt_model.state_dict() pt_model_dict = pt_model.state_dict()
load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and ( load_model_with_head_into_base_model = (pt_model.base_model_prefix in flax_state) and (
pt_model.base_model_prefix not in set([k.split(".")[0] for k in pt_model_dict.keys()]) pt_model.base_model_prefix not in {k.split(".")[0] for k in pt_model_dict.keys()}
) )
load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and ( load_base_model_into_model_with_head = (pt_model.base_model_prefix not in flax_state) and (
pt_model.base_model_prefix in set([k.split(".")[0] for k in pt_model_dict.keys()]) pt_model.base_model_prefix in {k.split(".")[0] for k in pt_model_dict.keys()}
) )
# keep track of unexpected & missing keys # keep track of unexpected & missing keys
......
...@@ -440,7 +440,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -440,7 +440,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
""" """
# Load the index # Load the index
state_sharded_dict = dict() state_sharded_dict = {}
for shard_file in shard_files: for shard_file in shard_files:
# load using msgpack utils # load using msgpack utils
...@@ -708,19 +708,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): ...@@ -708,19 +708,19 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
cached_file_kwargs = dict( cached_file_kwargs = {
cache_dir=cache_dir, "cache_dir": cache_dir,
force_download=force_download, "force_download": force_download,
proxies=proxies, "proxies": proxies,
resume_download=resume_download, "resume_download": resume_download,
local_files_only=local_files_only, "local_files_only": local_files_only,
use_auth_token=use_auth_token, "use_auth_token": use_auth_token,
user_agent=user_agent, "user_agent": user_agent,
revision=revision, "revision": revision,
subfolder=subfolder, "subfolder": subfolder,
_raise_exceptions_for_missing_entries=False, "_raise_exceptions_for_missing_entries": False,
_commit_hash=commit_hash, "_commit_hash": commit_hash,
) }
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
......
...@@ -258,7 +258,7 @@ def load_pytorch_state_dict_in_tf2_model( ...@@ -258,7 +258,7 @@ def load_pytorch_state_dict_in_tf2_model(
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
tf_loaded_numel = 0 tf_loaded_numel = 0
weight_value_tuples = [] weight_value_tuples = []
all_pytorch_weights = set(list(pt_state_dict.keys())) all_pytorch_weights = set(pt_state_dict.keys())
missing_keys = [] missing_keys = []
for symbolic_weight in symbolic_weights: for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name sw_name = symbolic_weight.name
...@@ -425,7 +425,7 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_ ...@@ -425,7 +425,7 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
) )
tf_weights_map[pt_name] = (tf_weight, transpose) tf_weights_map[pt_name] = (tf_weight, transpose)
all_tf_weights = set(list(tf_weights_map.keys())) all_tf_weights = set(tf_weights_map.keys())
loaded_pt_weights_data_ptr = {} loaded_pt_weights_data_ptr = {}
missing_keys_pt = [] missing_keys_pt = []
for pt_weight_name, pt_weight in current_pt_params_dict.items(): for pt_weight_name, pt_weight in current_pt_params_dict.items():
......
...@@ -584,7 +584,7 @@ def input_processing(func, config, **kwargs): ...@@ -584,7 +584,7 @@ def input_processing(func, config, **kwargs):
if "kwargs" in output: if "kwargs" in output:
del output["kwargs"] del output["kwargs"]
cast_output = dict() cast_output = {}
for key, val in output.items(): for key, val in output.items():
if isinstance(val, tf.Tensor) and val.dtype == tf.int64: if isinstance(val, tf.Tensor) and val.dtype == tf.int64:
cast_output[key] = tf.cast(val, tf.int32) cast_output[key] = tf.cast(val, tf.int32)
...@@ -737,7 +737,7 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s ...@@ -737,7 +737,7 @@ def load_tf_sharded_weights(model, shard_files, ignore_mismatched_sizes=False, s
# Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load # Since TF adds the name of the class to its weights, and uses the index and not the name of the layer to load
# the weight, we have to get rid of the first prefix of the name of the layer. # the weight, we have to get rid of the first prefix of the name of the layer.
model_keys = set() model_keys = set()
model_layer_map = dict() model_layer_map = {}
for i, k in enumerate(model.weights): for i, k in enumerate(model.weights):
if "model." in k.name or len(k.name.split("/")) == 1: if "model." in k.name or len(k.name.split("/")) == 1:
layer_name = k.name layer_name = k.name
...@@ -901,10 +901,10 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size ...@@ -901,10 +901,10 @@ def load_tf_weights_from_h5(model, resolved_archive_file, ignore_mismatched_size
) )
# Find the missing layers from the high level list of layers # Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name) missing_layers = list({layer.name for layer in model.layers} - saved_h5_model_layers_name)
# Find the unexpected layers from the high level list of layers # Find the unexpected layers from the high level list of layers
unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers])) unexpected_layers = list(saved_h5_model_layers_name - {layer.name for layer in model.layers})
saved_weight_names_set = set() saved_weight_names_set = set()
symbolic_weights_names = set() symbolic_weights_names = set()
weight_value_tuples = [] weight_value_tuples = []
...@@ -1349,7 +1349,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1349,7 +1349,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else: else:
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np") collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="np")
if collate_fn_args is None: if collate_fn_args is None:
collate_fn_args = dict() collate_fn_args = {}
if not isinstance(dataset, datasets.Dataset): if not isinstance(dataset, datasets.Dataset):
raise TypeError("Dataset argument should be a datasets.Dataset!") raise TypeError("Dataset argument should be a datasets.Dataset!")
...@@ -1471,7 +1471,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -1471,7 +1471,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
elif "mc_labels" in arg_names: elif "mc_labels" in arg_names:
return {"labels": "logits", "mc_labels": "mc_logits"} return {"labels": "logits", "mc_labels": "mc_logits"}
else: else:
return dict() return {}
def train_step(self, data): def train_step(self, data):
""" """
...@@ -2613,19 +2613,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu ...@@ -2613,19 +2613,19 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
cached_file_kwargs = dict( cached_file_kwargs = {
cache_dir=cache_dir, "cache_dir": cache_dir,
force_download=force_download, "force_download": force_download,
proxies=proxies, "proxies": proxies,
resume_download=resume_download, "resume_download": resume_download,
local_files_only=local_files_only, "local_files_only": local_files_only,
use_auth_token=use_auth_token, "use_auth_token": use_auth_token,
user_agent=user_agent, "user_agent": user_agent,
revision=revision, "revision": revision,
subfolder=subfolder, "subfolder": subfolder,
_raise_exceptions_for_missing_entries=False, "_raise_exceptions_for_missing_entries": False,
_commit_hash=commit_hash, "_commit_hash": commit_hash,
) }
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
......
...@@ -1271,7 +1271,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1271,7 +1271,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
len(encoder_modules) > 0 len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()]) all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
encoder_layer_pos = 0 encoder_layer_pos = 0
for name, module in decoder_modules.items(): for name, module in decoder_modules.items():
if name.isdigit(): if name.isdigit():
...@@ -2304,19 +2304,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2304,19 +2304,19 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
cached_file_kwargs = dict( cached_file_kwargs = {
cache_dir=cache_dir, "cache_dir": cache_dir,
force_download=force_download, "force_download": force_download,
proxies=proxies, "proxies": proxies,
resume_download=resume_download, "resume_download": resume_download,
local_files_only=local_files_only, "local_files_only": local_files_only,
use_auth_token=use_auth_token, "use_auth_token": use_auth_token,
user_agent=user_agent, "user_agent": user_agent,
revision=revision, "revision": revision,
subfolder=subfolder, "subfolder": subfolder,
_raise_exceptions_for_missing_entries=False, "_raise_exceptions_for_missing_entries": False,
_commit_hash=commit_hash, "_commit_hash": commit_hash,
) }
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
# Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
...@@ -2474,7 +2474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2474,7 +2474,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if is_sharded: if is_sharded:
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
else: else:
loaded_state_dict_keys = [k for k in state_dict.keys()] loaded_state_dict_keys = list(state_dict.keys())
if low_cpu_mem_usage or use_keep_in_fp32_modules: if low_cpu_mem_usage or use_keep_in_fp32_modules:
state_dict = None state_dict = None
...@@ -3046,12 +3046,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -3046,12 +3046,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = set([".".join(key.split(".")[:-1]) for key in names]) module_keys = {".".join(key.split(".")[:-1]) for key in names}
# torch.nn.ParameterList is a special case where two parameter keywords # torch.nn.ParameterList is a special case where two parameter keywords
# are appended to the module name, *e.g.* bert.special_embeddings.0 # are appended to the module name, *e.g.* bert.special_embeddings.0
module_keys = module_keys.union( module_keys = module_keys.union(
set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()]) {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()}
) )
retrieved_modules = [] retrieved_modules = []
......
...@@ -555,7 +555,7 @@ class FlaxBeitEncoder(nn.Module): ...@@ -555,7 +555,7 @@ class FlaxBeitEncoder(nn.Module):
) )
# stochastic depth decay rule # stochastic depth decay rule
drop_path_rates = [x for x in np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers)] drop_path_rates = list(np.linspace(0, self.config.drop_path_rate, self.config.num_hidden_layers))
self.layer = FlaxBeitLayerCollection( self.layer = FlaxBeitLayerCollection(
self.config, self.config,
window_size=self.window_size, window_size=self.window_size,
......
...@@ -318,7 +318,7 @@ class BertweetTokenizer(PreTrainedTokenizer): ...@@ -318,7 +318,7 @@ class BertweetTokenizer(PreTrainedTokenizer):
split_tokens = [] split_tokens = []
words = re.findall(r"\S+\n?", text) words = re.findall(r"\S+\n?", text)
for token in words: for token in words:
split_tokens.extend([t for t in self.bpe(token).split(" ")]) split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens return split_tokens
def normalizeTweet(self, tweet): def normalizeTweet(self, tweet):
...@@ -726,7 +726,7 @@ class TweetTokenizer: ...@@ -726,7 +726,7 @@ class TweetTokenizer:
words = WORD_RE.findall(safe_text) words = WORD_RE.findall(safe_text)
# Possibly alter the case, but avoid changing emoticons like :D into :d: # Possibly alter the case, but avoid changing emoticons like :D into :d:
if not self.preserve_case: if not self.preserve_case:
words = list(map((lambda x: x if EMOTICON_RE.search(x) else x.lower()), words)) words = [x if EMOTICON_RE.search(x) else x.lower() for x in words]
return words return words
......
...@@ -202,7 +202,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast): ...@@ -202,7 +202,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
"You should not supply a second sequence if the provided sequence of " "You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model." "ids is already formatted with special tokens for the model."
) )
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) return [1 if x in [self.sep_token_id, self.cls_token_id] else 0 for x in token_ids_0]
if token_ids_1 is None: if token_ids_1 is None:
return [1] + ([0] * len(token_ids_0)) + [1] return [1] + ([0] * len(token_ids_0)) + [1]
......
...@@ -132,8 +132,8 @@ class BioGptTokenizer(PreTrainedTokenizer): ...@@ -132,8 +132,8 @@ class BioGptTokenizer(PreTrainedTokenizer):
self.lang = "en" self.lang = "en"
self.sm = sacremoses self.sm = sacremoses
# cache of sm.MosesTokenizer instance # cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict() self.cache_moses_tokenizer = {}
self.cache_moses_detokenizer = dict() self.cache_moses_detokenizer = {}
""" Initialisation""" """ Initialisation"""
with open(vocab_file, encoding="utf-8") as vocab_handle: with open(vocab_file, encoding="utf-8") as vocab_handle:
...@@ -221,7 +221,7 @@ class BioGptTokenizer(PreTrainedTokenizer): ...@@ -221,7 +221,7 @@ class BioGptTokenizer(PreTrainedTokenizer):
split_tokens = [] split_tokens = []
for token in text: for token in text:
if token: if token:
split_tokens.extend([t for t in self.bpe(token).split(" ")]) split_tokens.extend(list(self.bpe(token).split(" ")))
return split_tokens return split_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment