Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
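
Nearly every hunk below is a mechanical consequence of two behaviours introduced in the Black 20.8 series: the "magic trailing comma" (a call or literal that already ends with a trailing comma is kept exploded, one element per line, even when it would fit within the line length) and docstring normalization. As a rough illustration rather than a description of the project's actual tooling, the following sketch assumes a Black release from the 20.8 line or later is installed and uses the 119-character line length the hunks below appear to be formatted to; it reproduces the trailing-comma effect on the tpu_name field touched in this commit.

import black  # assumes e.g. `pip install "black>=20.8b1"`

# The same field(...) call, without and with a trailing comma before the
# closing parenthesis. Only the second one ends in a trailing comma.
WITHOUT_COMMA = 'tpu_name: str = field(default=None, metadata={"help": "Name of TPU"})\n'
WITH_COMMA = 'tpu_name: str = field(default=None, metadata={"help": "Name of TPU"},)\n'

mode = black.FileMode(line_length=119)  # 119 appears to be the line length used in these hunks
print(black.format_str(WITHOUT_COMMA, mode=mode))  # fits within the limit: stays on one line
print(black.format_str(WITH_COMMA, mode=mode))     # trailing comma: exploded, one argument per line
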
......@@ -139,12 +139,12 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
return_length: bool = False,
verbose: bool = True,
) -> Dict[str, Any]:
""" Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
"""Convert the encoding representation (from low-level HuggingFace tokenizer output) to a python Dict.
Overflowing tokens are converted to additional examples (like batches) so the output values of
the dict are lists (overflows) of lists (tokens).
Overflowing tokens are converted to additional examples (like batches) so the output values of
the dict are lists (overflows) of lists (tokens).
Output shape: (overflows, sequence length)
Output shape: (overflows, sequence length)
"""
if return_token_type_ids is None:
return_token_type_ids = "token_type_ids" in self.model_input_names
......
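
The docstring-only hunks such as the one above (and the write_predictions_extended, XxxConfig, and template hunks further down) appear to come from Black 20.8's docstring handling, which drops the space after the opening triple quote and re-indents continuation lines. A small sketch under the same assumptions as above:

import black

SRC = (
    "def f():\n"
    '    """ Convert the encoding representation to a python Dict.\n'
    "        Output shape: (overflows, sequence length)\n"
    '    """\n'
)
# format_str only needs syntactically valid code; printing the result shows the
# space after the opening quotes removed and the continuation line re-indented.
print(black.format_str(SRC, mode=black.FileMode(line_length=119)))
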
......@@ -902,7 +902,12 @@ class XLMTokenizer(PreTrainedTokenizer):
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0,))
return list(
map(
lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0,
token_ids_0,
)
)
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
......
......@@ -141,10 +141,12 @@ class TrainingArguments:
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field(
default=False, metadata={"help": "Run evaluation during training at each logging step."},
default=False,
metadata={"help": "Run evaluation during training at each logging step."},
)
prediction_loss_only: bool = field(
default=False, metadata={"help": "When performing evaluation and predictions, only returns the loss."},
default=False,
metadata={"help": "When performing evaluation and predictions, only returns the loss."},
)
per_device_train_batch_size: int = field(
......
......@@ -100,7 +100,8 @@ class TFTrainingArguments(TrainingArguments):
"""
tpu_name: str = field(
default=None, metadata={"help": "Name of TPU"},
default=None,
metadata={"help": "Name of TPU"},
)
@cached_property
......
......@@ -703,10 +703,10 @@ def write_predictions_extended(
tokenizer,
verbose_logging,
):
""" XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed.
"""XLNet write prediction logic (more complex than Bert's).
Write final predictions to the json file and log-odds of null if needed.
Requires utils_squad_evaluate.py
Requires utils_squad_evaluate.py
"""
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction", ["feature_index", "start_index", "end_index", "start_log_prob", "end_log_prob"]
......
......@@ -31,47 +31,47 @@ XXX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
class XxxConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.XXXModel`.
It is used to instantiate a XXX model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the XXX `xxx-base-uncased <https://huggingface.co/xxx/xxx-base-uncased>`__ architecture.
This is the configuration class to store the configuration of a :class:`~transformers.XXXModel`.
It is used to instantiate a XXX model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the XXX `xxx-base-uncased <https://huggingface.co/xxx/xxx-base-uncased>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.
Args:
vocab_size (:obj:`int`, optional, defaults to 30522):
Vocabulary size of the XXX model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.XXXModel`.
hidden_size (:obj:`int`, optional, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, optional, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
Args:
vocab_size (:obj:`int`, optional, defaults to 30522):
Vocabulary size of the XXX model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.XXXModel`.
hidden_size (:obj:`int`, optional, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, optional, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, optional, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, optional, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, optional, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the :obj:`truncated_normal_initializer` for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-5):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, optional, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
kwargs:
Additional arguments for common configurations, passed to :class:`~transformers.PretrainedConfig`.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, optional, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, optional, defaults to 2):
The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
initializer_range (:obj:`float`, optional, defaults to 0.02):
The standard deviation of the :obj:`truncated_normal_initializer` for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-5):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, optional, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
kwargs:
Additional arguments for common configurations, passed to :class:`~transformers.PretrainedConfig`.
"""
model_type = "xxx"
......
......@@ -223,7 +223,10 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
pooled_output = self.pooler(sequence_output)
if not return_dict:
return (sequence_output, pooled_output,) + encoder_outputs[1:]
return (
sequence_output,
pooled_output,
) + encoder_outputs[1:]
return TFBaseModelOutputWithPooling(
last_hidden_state=sequence_output,
......@@ -241,8 +244,8 @@ class TFXxxMainLayer(tf.keras.layers.Layer):
# pointers for your model.
####################################################
class TFXxxPreTrainedModel(TFPreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XxxConfig
......@@ -422,7 +425,10 @@ class TFXxxForMaskedLM(TFXxxPreTrainedModel, TFMaskedLanguageModelingLoss):
return ((loss,) + output) if loss is not None else output
return TFMaskedLMOutput(
loss=loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=prediction_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -503,7 +509,10 @@ class TFXxxForSequenceClassification(TFXxxPreTrainedModel, TFSequenceClassificat
return ((loss,) + output) if loss is not None else output
return TFSequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -524,7 +533,7 @@ class TFXxxForMultipleChoice(TFXxxPreTrainedModel, TFMultipleChoiceLoss):
@property
def dummy_inputs(self):
""" Dummy inputs to build the network.
"""Dummy inputs to build the network.
Returns:
tf.Tensor with dummy inputs
......@@ -631,7 +640,10 @@ class TFXxxForMultipleChoice(TFXxxPreTrainedModel, TFMultipleChoiceLoss):
return ((loss,) + output) if loss is not None else output
return TFMultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -710,7 +722,10 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel, TFTokenClassificationLos
return ((loss,) + output) if loss is not None else output
return TFTokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -59,8 +59,7 @@ XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_xxx(model, config, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model.
"""
"""Load tf checkpoints in a pytorch model."""
try:
import re
......@@ -189,8 +188,8 @@ XxxPooler = nn.Module
class XxxPreTrainedModel(PreTrainedModel):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
config_class = XxxConfig
......@@ -290,9 +289,9 @@ class XxxModel(XxxPreTrainedModel):
self.embeddings.word_embeddings = new_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
......@@ -517,7 +516,10 @@ class XxxForSequenceClassification(XxxPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -603,7 +605,10 @@ class XxxForMultipleChoice(XxxPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return MultipleChoiceModelOutput(
loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=reshaped_logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......@@ -686,7 +691,10 @@ class XxxForTokenClassification(XxxPreTrainedModel):
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions,
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
......
......@@ -39,7 +39,8 @@ if is_torch_available():
class AlbertModelTester:
def __init__(
self, parent,
self,
parent,
):
self.parent = parent
self.batch_size = 13
......
......@@ -54,7 +54,8 @@ PGE_ARTICLE = """ PG&E stated it scheduled the blackouts in response to forecast
@require_torch
class ModelTester:
def __init__(
self, parent,
self,
parent,
):
self.parent = parent
self.batch_size = 13
......@@ -76,7 +77,9 @@ class ModelTester:
torch.manual_seed(0)
def prepare_config_and_inputs_for_common(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(3,)
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
)
input_ids[:, -1] = 2 # Eos Token
config = BartConfig(
......@@ -100,7 +103,9 @@ class ModelTester:
def prepare_bart_inputs_dict(
config, input_ids, attention_mask=None,
config,
input_ids,
attention_mask=None,
):
if attention_mask is None:
attention_mask = input_ids.ne(config.pad_token_id)
......@@ -261,7 +266,11 @@ class BartHeadTests(unittest.TestCase):
sequence_labels = ids_tensor([batch_size], 2).to(torch_device)
model = BartForQuestionAnswering(config)
model.to(torch_device)
outputs = model(input_ids=input_ids, start_positions=sequence_labels, end_positions=sequence_labels,)
outputs = model(
input_ids=input_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.assertEqual(outputs["start_logits"].shape, input_ids.shape)
self.assertEqual(outputs["end_logits"].shape, input_ids.shape)
......@@ -491,7 +500,11 @@ class BartModelIntegrationTests(unittest.TestCase):
EXPECTED_SUMMARY = "California's largest power company has begun shutting off electricity to thousands of customers in the state."
dct = tok.batch_encode_plus(
[PGE_ARTICLE], max_length=1024, padding="max_length", truncation=True, return_tensors="pt",
[PGE_ARTICLE],
max_length=1024,
padding="max_length",
truncation=True,
return_tensors="pt",
).to(torch_device)
hypotheses_batch = model.generate(
......@@ -506,7 +519,10 @@ class BartModelIntegrationTests(unittest.TestCase):
decoder_start_token_id=model.config.eos_token_id,
)
decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True,)
decoded = tok.batch_decode(
hypotheses_batch,
skip_special_tokens=True,
)
self.assertEqual(EXPECTED_SUMMARY, decoded[0])
def test_xsum_config_generation_params(self):
......
......@@ -264,7 +264,10 @@ class BertModelTester:
model.to(torch_device)
model.eval()
result = model(
input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, next_sentence_label=sequence_labels,
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
next_sentence_label=sequence_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
......
......@@ -33,7 +33,9 @@ class CamembertModelIntegrationTest(unittest.TestCase):
model.to(torch_device)
input_ids = torch.tensor(
[[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]], device=torch_device, dtype=torch.long,
[[5, 121, 11, 660, 16, 730, 25543, 110, 83, 6]],
device=torch_device,
dtype=torch.long,
) # J'aime le camembert !
output = model(input_ids)["last_hidden_state"]
expected_shape = torch.Size((1, 10, 768))
......
......@@ -330,7 +330,9 @@ class ModelTesterMixin:
# Prepare head_mask
# Set require_grad after having prepared the tensor to avoid error (leaf variable has been moved into the graph interior)
head_mask = torch.ones(
self.model_tester.num_hidden_layers, self.model_tester.num_attention_heads, device=torch_device,
self.model_tester.num_hidden_layers,
self.model_tester.num_attention_heads,
device=torch_device,
)
head_mask[0, 0] = 0
head_mask[-1, :-1] = 0
......@@ -370,7 +372,10 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
......@@ -399,7 +404,10 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
......@@ -432,7 +440,10 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
......@@ -463,7 +474,10 @@ class ModelTesterMixin:
return
for model_class in self.all_model_classes:
(config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if "head_mask" in inputs_dict:
del inputs_dict["head_mask"]
......@@ -534,7 +548,8 @@ class ModelTesterMixin:
seq_length = self.model_tester.seq_length
self.assertListEqual(
list(hidden_states[0].shape[-2:]), [seq_length, self.model_tester.hidden_size],
list(hidden_states[0].shape[-2:]),
[seq_length, self.model_tester.hidden_size],
)
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -550,7 +565,10 @@ class ModelTesterMixin:
check_hidden_states_output(inputs_dict, config, model_class)
def test_feed_forward_chunking(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
torch.manual_seed(0)
config = copy.deepcopy(original_config)
......@@ -570,7 +588,10 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
def test_resize_tokens_embeddings(self):
(original_config, inputs_dict,) = self.model_tester.prepare_config_and_inputs_for_common()
(
original_config,
inputs_dict,
) = self.model_tester.prepare_config_and_inputs_for_common()
if not self.test_resize_embeddings:
return
......@@ -844,7 +865,14 @@ class ModelTesterMixin:
model.generate(input_ids, do_sample=False, num_return_sequences=3, num_beams=2)
# num_return_sequences > 1, sample
self._check_generated_ids(model.generate(input_ids, do_sample=True, num_beams=2, num_return_sequences=2,))
self._check_generated_ids(
model.generate(
input_ids,
do_sample=True,
num_beams=2,
num_return_sequences=2,
)
)
# num_return_sequences > 1, greedy
self._check_generated_ids(model.generate(input_ids, do_sample=False, num_beams=2, num_return_sequences=2))
......
......@@ -30,7 +30,8 @@ if is_torch_available():
class CTRLModelTester:
def __init__(
self, parent,
self,
parent,
):
self.parent = parent
self.batch_size = 14
......
......@@ -179,7 +179,9 @@ if is_torch_available():
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
result = model(
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
......
......@@ -149,7 +149,10 @@ class DPRModelTester:
model = DPRReader(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask,)
result = model(
input_ids,
attention_mask=input_mask,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
......@@ -173,7 +176,15 @@ class DPRModelTester:
@require_torch
class DPRModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (DPRContextEncoder, DPRQuestionEncoder, DPRReader,) if is_torch_available() else ()
all_model_classes = (
(
DPRContextEncoder,
DPRQuestionEncoder,
DPRReader,
)
if is_torch_available()
else ()
)
test_resize_embeddings = False
test_missing_keys = False # why?
......
......@@ -39,7 +39,8 @@ if is_torch_available():
class ElectraModelTester:
def __init__(
self, parent,
self,
parent,
):
self.parent = parent
self.batch_size = 13
......
......@@ -391,7 +391,11 @@ class EncoderDecoderMixin:
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
with torch.no_grad():
outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
outputs = model_2(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
......@@ -401,7 +405,9 @@ class EncoderDecoderMixin:
model_1.to(torch_device)
after_outputs = model_1(
input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
......
......@@ -39,7 +39,8 @@ if is_torch_available():
class FlaubertModelTester(object):
def __init__(
self, parent,
self,
parent,
):
self.parent = parent
self.batch_size = 13
......
......@@ -244,7 +244,8 @@ class GPT2ModelTester:
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
......
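
As a closing usage note, and again a sketch under the assumptions above rather than the project's actual check command, the same API can be used to verify that a file already matches what this commit produced; the path below is only an example.

import black
from pathlib import Path

path = Path("src/transformers/training_args.py")  # example path, adjust as needed
source = path.read_text()
formatted = black.format_str(source, mode=black.FileMode(line_length=119))
print("already formatted" if formatted == source else "would be reformatted")
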