Commit 5988d2cc authored by yuguo960516

bert-large

parent 478602ba
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import OrderedDict
from nltk.translate.bleu_score import corpus_bleu
from libai.utils import distributed as dist
from .evaluator import DatasetEvaluator
class BLEUEvaluator(DatasetEvaluator):
"""
    Evaluate the BLEU (Bilingual Evaluation Understudy) score.
BLEU is a score for comparing a candidate translation
of text to one or more reference translations.
"""
def __init__(self):
super().__init__()
self._predictions = []
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
candidate = outputs["candidate"]
reference = inputs["reference"]
self._predictions.append({"candidate": candidate, "reference": reference})
def evaluate(self):
if not dist.is_main_process():
return {}
else:
predictions = self._predictions
candidates = []
references = []
for pred in predictions:
candidates.append(pred["candidate"])
references.append(pred["reference"])
bleu_score = corpus_bleu(references, candidates)
self._results = OrderedDict()
self._results["bleu_score"] = bleu_score
return copy.deepcopy(self._results)
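# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original file).
# It assumes each candidate is a token list and each reference entry is a list
# of tokenized reference translations, which is the shape nltk's corpus_bleu
# expects given the call made in evaluate() above.
def _bleu_evaluator_demo():
    evaluator = BLEUEvaluator()
    evaluator.reset()
    outputs = {"candidate": ["the", "cat", "sat", "on", "the", "mat"]}
    inputs = {
        "reference": [
            ["the", "cat", "is", "on", "the", "mat"],
            ["there", "is", "a", "cat", "on", "the", "mat"],
        ]
    }
    evaluator.process(inputs, outputs)
    return evaluator.evaluate()  # {"bleu_score": ...} on the main process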
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections import OrderedDict
from libai.utils import distributed as dist
from .evaluator import DatasetEvaluator
def accuracy(output, target, topk=(1,)):
maxk = min(max(topk), output.size()[1])
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
return [
(correct[: min(k, maxk)].reshape(-1).float().sum(0) * 100.0 / batch_size).item()
for k in topk
]
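# Minimal sketch of calling `accuracy` (illustrative only, not part of the
# original file); it assumes local oneflow tensors of shape [batch, classes].
def _accuracy_demo():
    import oneflow as flow  # local import: this module does not import flow
    logits = flow.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])
    labels = flow.tensor([1, 1], dtype=flow.int64)
    # sample 0 is right at top-1; sample 1 only at top-2 -> returns [50.0, 100.0]
    return accuracy(logits, labels, topk=(1, 2))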
class ClsEvaluator(DatasetEvaluator):
"""
Evaluate accuracy for classification.
The metrics range from 0 to 100 (instead of 0 to 1).
    We support evaluating different top-k accuracies.
    You can set `cfg.train.topk=(1, 5, N)` according to your needs.
"""
def __init__(self, topk=(1, 5)):
self.topk = topk
self._predictions = []
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
pred_logits = outputs["prediction_scores"]
labels = inputs["labels"]
# measure accuracy
topk_acc = accuracy(pred_logits, labels, topk=self.topk)
num_correct_acc_topk = [acc * labels.size(0) / 100 for acc in topk_acc]
self._predictions.append(
{"num_correct_topk": num_correct_acc_topk, "num_samples": labels.size(0)}
)
def evaluate(self):
if not dist.is_main_process():
return {}
else:
predictions = self._predictions
total_correct_num = OrderedDict()
for top_k in self.topk:
total_correct_num["Acc@" + str(top_k)] = 0
total_samples = 0
for prediction in predictions:
for top_k, num_correct_n in zip(self.topk, prediction["num_correct_topk"]):
total_correct_num["Acc@" + str(top_k)] += int(num_correct_n)
total_samples += int(prediction["num_samples"])
self._results = OrderedDict()
for top_k, topk_correct_num in total_correct_num.items():
self._results[top_k] = topk_correct_num / total_samples * 100
return copy.deepcopy(self._results)
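# Minimal usage sketch for ClsEvaluator (illustrative only, not part of the
# original file); tensors are local here, whereas in training they come from
# the distributed evaluation loop.
def _cls_evaluator_demo():
    import oneflow as flow  # local import: this module does not import flow
    evaluator = ClsEvaluator(topk=(1, 2))
    evaluator.reset()
    outputs = {"prediction_scores": flow.tensor([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]])}
    inputs = {"labels": flow.tensor([1, 1], dtype=flow.int64)}
    evaluator.process(inputs, outputs)
    # on the main process: {"Acc@1": 50.0, "Acc@2": 100.0}
    return evaluator.evaluate()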
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import time
from collections import OrderedDict, abc
from contextlib import ExitStack, contextmanager
from typing import Callable, List, Union
import oneflow as flow
from libai.utils import distributed as dist
from libai.utils.logger import log_every_n_seconds
from .utils import pad_batch
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/evaluation/evaluator.py
# --------------------------------------------------------
class DatasetEvaluator:
"""
Base class for a dataset evaluator.
    The function :func:`inference_on_dataset` runs the model over
    all samples in the dataset, and uses a DatasetEvaluator to process the inputs/outputs.
    This class will accumulate information about the inputs/outputs (by :meth:`process`),
    and produce evaluation results in the end (by :meth:`evaluate`).
"""
def reset(self):
"""
Preparation for a new round of evaluation.
Should be called before starting a round of evaluation.
"""
def process(self, inputs, outputs):
"""
Process the pair of inputs and outputs.
.. code-block:: python
pred_logits = outputs["prediction_scores"]
labels = inputs["labels"]
# do evaluation on pred_logits/labels pair
...
Args:
            inputs (dict): the inputs that are used to call the model.
outputs (dict): the return dict of `model(**inputs)`
"""
def evaluate(self):
"""
Evaluate/summarize the performance after processing all input/output pairs.
Returns:
dict:
A new evaluator class can return a dict of arbitrary format
as long as the user can process the results.
In our train_net.py, we expect the following format:
* key: the name of the task (e.g., Classification)
* value: a dict of {metric name: score}, e.g.: {"Acc@1": 75.0}
"""
class DatasetEvaluators(DatasetEvaluator):
"""
Wrapper class to combine multiple :class:`DatasetEvaluator` instances.
This class dispatches every evaluation call to
all of its :class:`DatasetEvaluator`.
"""
def __init__(self, evaluators):
"""
Args:
evaluators (list): the evaluators to combine.
"""
super().__init__()
self._evaluators = evaluators
def reset(self):
for evaluator in self._evaluators:
evaluator.reset()
def process(self, inputs, outputs):
for evaluator in self._evaluators:
evaluator.process(inputs, outputs)
def evaluate(self):
results = OrderedDict()
for evaluator in self._evaluators:
result = evaluator.evaluate()
if dist.is_main_process() and result is not None:
for k, v in result.items():
assert (
k not in results
), "Different evaluators produce results with the same key {}".format(k)
results[k] = v
return results
def inference_on_dataset(
model,
data_loader,
batch_size,
eval_iter,
get_batch: Callable,
input_placement_device: str,
evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None],
):
"""
Run model on the data_loader and evaluate the metrics with evaluator.
Also benchmark the inference speed of `model.__call__` accurately.
The model will be used in eval mode.
Args:
model (callable): a callable which takes an object from
`data_loader` and returns some outputs.
If it's an nn.Module, it will be temporarily set to `eval` mode.
If you wish to evaluate a model in `training` mode instead, you can
wrap the given model and override its behavior of `.eval()` and `.train()`.
batch_size: batch size for inference
data_loader: an iterable object with a length.
The elements it generates will be the inputs to the model.
        eval_iter: the number of iterations to run for evaluation
        get_batch: a callable for fetching data from the data_loader
        input_placement_device: used in get_batch; set it to `cuda` or `cpu`.
see input_placement_device in `libai.configs.common.train.py` for more details.
evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
but don't want to do any evaluation.
Returns:
The return value of `evaluator.evaluate()`
"""
num_devices = dist.get_world_size()
logger = logging.getLogger(__name__)
total_samples = len(data_loader.dataset) # inference data loader must have a fixed length
if evaluator is None:
# create a no-op evaluator
evaluator = DatasetEvaluators([])
if isinstance(evaluator, abc.MutableSequence):
evaluator = DatasetEvaluators(evaluator)
evaluator.reset()
num_warmup = min(5, len(data_loader) - 1)
start_time = time.perf_counter()
total_data_time = 0
total_compute_time = 0
total_eval_time = 0
consumed_samples = 0
dps = dist.get_data_parallel_size()
last_batch_lack = (dps - (total_samples % dps)) % dps
# reset total samples
real_eval_iter = min(eval_iter, len(data_loader))
total_samples = min(real_eval_iter * batch_size, len(data_loader.dataset))
logger.info(
f"with eval_iter {eval_iter}, "
f"reset total samples {len(data_loader.dataset)} to {total_samples}"
)
logger.info(f"Start inference on {total_samples} samples")
with ExitStack() as stack:
if isinstance(model, (flow.nn.Module, flow.nn.Graph)):
stack.enter_context(inference_context(model))
stack.enter_context(flow.no_grad())
start_data_time = time.perf_counter()
for idx, inputs in enumerate(data_loader):
if idx >= real_eval_iter:
break
total_data_time += time.perf_counter() - start_data_time
if idx == num_warmup:
start_time = time.perf_counter()
total_data_time = 0
total_compute_time = 0
total_eval_time = 0
start_compute_time = time.perf_counter()
# model forward
data = get_batch(inputs, input_placement_device)
is_last_batch = idx == len(data_loader) - 1
            padded_data, valid_sample = pad_batch(data, batch_size, last_batch_lack, is_last_batch)
            outputs = model(**padded_data)
# get valid sample
valid_data = {
key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample]
for key, value in data.items()
}
valid_outputs = {}
for key, value in outputs.items():
value = dist.tensor_to_rank0(value, to_local=True)
if value.ndim > 1:
valid_outputs[key] = value[:valid_sample] # Slice if it's batched output
else:
valid_outputs[key] = value
if flow.cuda.is_available():
dist.synchronize()
total_compute_time += time.perf_counter() - start_compute_time
start_eval_time = time.perf_counter()
if dist.is_main_process():
evaluator.process(valid_data, valid_outputs)
dist.synchronize()
total_eval_time += time.perf_counter() - start_eval_time
consumed_samples += valid_sample
iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
data_seconds_per_iter = total_data_time / iters_after_start
compute_seconds_per_iter = total_compute_time / iters_after_start
eval_seconds_per_iter = total_eval_time / iters_after_start
total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
eta = datetime.timedelta(
seconds=int(total_seconds_per_iter * (total_samples // batch_size - idx - 1))
)
log_every_n_seconds(
logging.INFO,
(
f"Inference done {consumed_samples}/{total_samples}. "
f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
f"Total: {total_seconds_per_iter:.4f} s/iter. "
f"ETA={eta}"
),
n=5,
)
start_data_time = time.perf_counter()
# Measure the time only for this worker (before the synchronization barrier)
total_time = time.perf_counter() - start_time
total_time_str = str(datetime.timedelta(seconds=total_time))
# NOTE this format is parsed by grep
logger.info("Total valid samples: {}".format(consumed_samples))
logger.info(
"Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
total_time_str, total_time / (total_samples - num_warmup), num_devices
)
)
total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
logger.info(
"Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
total_compute_time_str,
total_compute_time / (total_samples - num_warmup),
num_devices,
)
)
results = evaluator.evaluate()
    # An evaluator may return None when not run in the main process.
    # Replace it with an empty dict to make it easier for downstream code to handle.
if results is None:
results = {}
return results
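# Minimal call sketch (illustrative only, not part of the original file).
# `model`, `test_loader`, `get_batch` and `evaluator` are assumed to come from
# the trainer; in LiBai this call is normally issued by the evaluation hooks.
def _inference_on_dataset_demo(model, test_loader, get_batch, evaluator):
    results = inference_on_dataset(
        model,
        test_loader,
        batch_size=16,        # padded per data-parallel rank by pad_batch
        eval_iter=100,        # capped at len(test_loader) inside the function
        get_batch=get_batch,  # e.g. the trainer's get_batch function
        input_placement_device="cuda",
        evaluator=evaluator,  # a DatasetEvaluator, a list of them, or None
    )
    return results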
@contextmanager
def inference_context(model):
"""
A context where the model is temporarily changed to eval mode,
and restored to previous mode afterwards.
Args:
        model: a OneFlow model in eager mode (``flow.nn.Module``) or graph mode (``flow.nn.Graph``)
"""
training_mode = model.model.training if isinstance(model, flow.nn.Graph) else model.training
if isinstance(model, flow.nn.Graph):
model.model.eval()
else:
model.eval()
yield
if isinstance(model, flow.nn.Graph):
model.model.train(training_mode)
else:
model.train(training_mode)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import math
from collections import OrderedDict
from libai.utils import distributed as dist
from .evaluator import DatasetEvaluator
class PPLEvaluator(DatasetEvaluator):
"""
    Evaluate perplexity for language models.
Perplexity is a measurement of how well a probability distribution or
probability model predicts a sample.
"""
def __init__(self):
self._predictions = []
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
for k, v in outputs.items():
ppl = math.exp(min(20, v.item()))
self._predictions.append({f"{k}_PPL": ppl})
def evaluate(self):
if not dist.is_main_process():
return {}
else:
predictions = self._predictions
self._results = OrderedDict()
for prediction in predictions:
for k, v in prediction.items():
if k not in self._results:
self._results[k] = 0
self._results[k] += v
for k in self._results.keys():
self._results[k] /= len(predictions)
return copy.deepcopy(self._results)
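# Minimal usage sketch (illustrative only, not part of the original file):
# feeding per-iteration LM losses and reading back the averaged perplexity.
# The "lm_loss" key is an assumption about the model's output dict.
def _ppl_evaluator_demo():
    import oneflow as flow  # local import: this module does not import flow
    evaluator = PPLEvaluator()
    evaluator.reset()
    for loss_value in (2.1, 1.9):
        evaluator.process(inputs={}, outputs={"lm_loss": flow.tensor(loss_value)})
    # on the main process: {"lm_loss_PPL": mean of exp(loss) over iterations}
    return evaluator.evaluate()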
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from collections import OrderedDict
import numpy as np
from scipy.stats import pearsonr, spearmanr
from libai.utils import distributed as dist
from .evaluator import DatasetEvaluator
logger = logging.getLogger(__name__)
class RegEvaluator(DatasetEvaluator):
def __init__(self):
self._predictions = []
def reset(self):
self._predictions = []
def process(self, inputs, outputs):
pred_logits = outputs["prediction_scores"]
labels = inputs["labels"]
        # collect predictions and labels for the correlation metrics
preds = pred_logits.cpu().topk(1)[1].squeeze(1).numpy()
labels = labels.cpu().numpy()
self._predictions.append({"preds": preds, "labels": labels})
def evaluate(self):
if not dist.is_main_process():
return {}
else:
predictions = self._predictions
preds = np.array([])
labels = np.array([])
for prediction in predictions:
preds = np.concatenate((preds, prediction["preds"]))
labels = np.concatenate((labels, prediction["labels"]))
pearson_corr = pearsonr(preds, labels)[0]
spearman_corr = spearmanr(preds, labels)[0]
corr = (pearson_corr + spearman_corr) / 2
self._results = OrderedDict()
self._results["pearson"] = pearson_corr
self._results["spearman"] = spearman_corr
self._results["corr"] = corr
return copy.deepcopy(self._results)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections.abc import Mapping
import oneflow as flow
from libai.utils import distributed as dist
def pad_batch(x_dict, batch_size, last_batch_lack, is_last_batch):
x = list(x_dict.values())[0]
tensor_batch = x.shape[0]
assert tensor_batch <= batch_size
if tensor_batch == batch_size and not is_last_batch:
return x_dict, batch_size
valid_sample = tensor_batch - last_batch_lack
data_parallel_size = dist.get_data_parallel_size()
assert tensor_batch % data_parallel_size == 0
tensor_micro_batch_size = tensor_batch // data_parallel_size
padded_dict = {}
for key, xi in x_dict.items():
pad_shape = (batch_size, *xi.shape[1:])
local_xi = xi.to_global(
sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cuda")
).to_local()
padded_xi = flow.zeros(pad_shape, dtype=xi.dtype, device="cuda")
padded_xi[:tensor_batch, ...] = padded_xi[:tensor_batch, ...] + local_xi
for i in range(last_batch_lack - 1):
start_idx = tensor_micro_batch_size * (data_parallel_size - i - 1) - 1
padded_xi[start_idx:-1] = padded_xi[start_idx + 1 :]
padded_xi = padded_xi.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]), placement=xi.placement
).to_global(sbp=xi.sbp)
padded_dict[key] = padded_xi
return padded_dict, valid_sample
def print_csv_format(results):
"""
Print main metrics in a particular format
so that they are easy to copypaste into a spreadsheet.
Args:
results (OrderedDict[dict]): task_name -> {metric -> score}
unordered dict can also be printed, but in arbitrary order
"""
assert isinstance(results, Mapping) or not len(results), results
logger = logging.getLogger(__name__)
for task, res in results.items():
if isinstance(res, Mapping):
# Don't print "AP-category" metrics since they are usually not tracked.
important_res = [(k, v) for k, v in res.items() if "-" not in k]
logger.info("copypaste: Task: {}".format(task))
logger.info("copypaste: " + ",".join([k[0] for k in important_res]))
logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res]))
else:
logger.info(f"copypaste: {task}={res}")
def flatten_results_dict(results):
"""
Expand a hierarchical dict of scalars into a flat dict of scalars.
If results[k1][k2][k3] = v, the returned dict will have the entry
{"k1/k2/k3": v}.
Args:
results (dict):
"""
r = {}
for k, v in results.items():
if isinstance(v, Mapping):
v = flatten_results_dict(v)
for kk, vv in v.items():
r[k + "/" + kk] = vv
else:
r[k] = v
return r
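# Minimal sketch (illustrative only, not part of the original file) of the two
# helpers above on a fake results dict; the task and metric names are made up.
def _results_utils_demo():
    results = {"Classification": {"Acc@1": 75.0, "Acc@5": 92.5}}
    print_csv_format(results)
    # flatten_results_dict turns nested keys into "task/metric" entries:
    # {"Classification/Acc@1": 75.0, "Classification/Acc@5": 92.5}
    return flatten_results_dict(results)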
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
import oneflow as flow
from libai.config import LazyConfig, try_get_key
from libai.engine import DefaultTrainer
from libai.utils import distributed as dist
from libai.utils.logger import setup_logger
logger = setup_logger(distributed_rank=dist.get_rank())
logger = logging.getLogger("libai.inference")
class BasePipeline(metaclass=ABCMeta):
"""
    Base class for all task pipelines.
"""
def __init__(
self,
config_file,
data_parallel=None,
tensor_parallel=None,
pipeline_parallel=None,
pipeline_stage_id=None,
pipeline_num_layers=None,
model_path=None,
mode="libai",
**kwargs,
):
# init cfg
self.cfg = LazyConfig.load(config_file)
flow.boxing.nccl.set_fusion_threshold_mbytes(
try_get_key(self.cfg, "train.nccl_fusion_threshold_mb", default=16)
)
flow.boxing.nccl.set_fusion_max_ops_num(
try_get_key(self.cfg, "train.nccl_fusion_max_ops", default=24)
)
self.update_cfg(
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
pipeline_num_layers,
)
dist.setup_dist_util(self.cfg.train.dist)
assert (
self.cfg.train.dist.data_parallel_size == 1
), "not support data parallel yet, only support tensor and pipeline parallel"
logger.info(self.cfg.train.dist)
        # initialize and load the model
self.model = self.load_pretrain_weight(self.cfg.model, model_path, mode=mode)
self.model._apply(dist.convert_to_distributed_default_setting)
self.model = self.model.eval()
        # initialize the tokenizer
if dist.is_main_process():
self.tokenizer = self.build_tokenizer(self.cfg)
else:
self.tokenizer = None
self.tokenizer = dist.broadcast_py_object(self.tokenizer, src=0)
# set parameters
(
self._preprocess_params,
self._forward_params,
self._postprocess_params,
) = self._parse_parameters(**kwargs)
def update_cfg(
self,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_stage_id=None,
pipeline_num_layers=None,
):
self.cfg.train.dist.data_parallel_size = data_parallel
self.cfg.train.dist.tensor_parallel_size = tensor_parallel
self.cfg.train.dist.pipeline_parallel_size = pipeline_parallel
self.cfg.train.dist.custom_pipeline_stage_id = pipeline_stage_id
if pipeline_num_layers is not None:
self.cfg.train.dist.pipeline_num_layers = pipeline_num_layers
if self.cfg.train.dist.pipeline_parallel_size > 1:
assert (
try_get_key(self.cfg.train.dist, "pipeline_num_layers") is not None
), "cfg.train.dist.pipeline_num_layers must be set when run pipeline parallel"
def load_pretrain_weight(
self,
libai_cfg_model,
model_path,
mode="libai",
):
"""load pretrained model.
Args:
libai_cfg_model (libai.models): Lazy config Model in Libai, you can import it
by `from libai.config.configs.common.models.bert
import pretrain_model as libai_cfg_model`
model_path (str): The directory path of pretrained model
mode (str): set it to `libai` for loading trained model from libai,
set it to `random` for quickly debugging by random initialized model
"""
if mode == "libai":
from libai.models.utils.model_utils.base_loader import ModelLoaderLiBai
model_loader = ModelLoaderLiBai(libai_cfg_model, libai_cfg_model.cfg, model_path)
model_loader.base_model_prefix_1 = None
model_loader.base_model_prefix_2 = ""
return model_loader.load()
elif mode == "random":
return DefaultTrainer.build_model(self.cfg)
else:
raise NotImplementedError
def build_tokenizer(self, cfg):
tokenizer = None
if try_get_key(cfg, "tokenization") is not None:
tokenizer = DefaultTrainer.build_tokenizer(cfg)
return tokenizer
@abstractmethod
def _parse_parameters(self, **pipeline_parameters):
raise NotImplementedError("_parse_parameters not implemented")
def __call__(self, inputs, *args, batch_size=None, **kwargs) -> dict:
preprocess_params, forward_params, postprocess_params = self._parse_parameters(
**kwargs
) # noqa
# Fuse __init__ params and __call__ params without modifying the __init__ ones.
preprocess_params = {**self._preprocess_params, **preprocess_params}
forward_params = {**self._forward_params, **forward_params}
postprocess_params = {**self._postprocess_params, **postprocess_params}
with flow.no_grad():
model_inputs_dict = self.preprocess(inputs, **preprocess_params)
model_outputs_dict = self.forward(model_inputs_dict, **forward_params)
model_outputs_dict = self.to_local(model_outputs_dict)
if dist.is_main_process():
outputs_dict = self.postprocess(model_outputs_dict, **postprocess_params)
else:
outputs_dict = {}
dist.synchronize()
return outputs_dict
def to_local(self, model_outputs_dict):
for key, value in model_outputs_dict.items():
if isinstance(value, flow.Tensor) and value.is_global:
model_outputs_dict[key] = dist.ttol(
value, ranks=[0] if value.placement.ranks.ndim == 1 else [[0]]
)
if flow.cuda.is_available():
dist.synchronize()
return model_outputs_dict
@abstractmethod
def preprocess(self, input_: Any, **preprocess_parameters: Dict) -> dict:
raise NotImplementedError("preprocess not implemented")
@abstractmethod
def forward(self, **kwargs: Dict) -> dict:
raise NotImplementedError("forward not implemented")
@abstractmethod
def postprocess(self, **kwargs: Dict) -> dict:
raise NotImplementedError("postprocess not implemented")
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple
import oneflow as flow
from libai.utils import distributed as dist
class BeamScorer(ABC):
@abstractmethod
def process(
self,
input_ids: flow.Tensor,
next_scores: flow.Tensor,
next_tokens: flow.Tensor,
next_indices: flow.Tensor,
**kwargs,
):
raise NotImplementedError("This is an abstract method.")
class BeamHypotheses:
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.beams = []
self.worst_score = 1e9
def __len__(self) -> int:
"""
Number of hypotheses in the list.
"""
return len(self.beams)
def add(
self, hyp: flow.Tensor, sum_logprobs: float, beam_indices: Optional[flow.Tensor] = None
):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
"""
        If there are enough hypotheses and none of the hypotheses being generated
can become better than the worst one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
cur_score = best_sum_logprobs / cur_len ** self.length_penalty
ret = self.worst_score >= cur_score
return ret
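# Minimal sketch (illustrative only, not part of the original file) of
# BeamHypotheses bookkeeping with local tensors; the token ids and scores are
# made up.
def _beam_hypotheses_demo():
    hyps = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)
    hyps.add(flow.tensor([101, 7, 8, 102]), sum_logprobs=-2.0)
    hyps.add(flow.tensor([101, 7, 9, 102]), sum_logprobs=-3.5)
    # with both slots filled, is_done asks whether an open beam with the given
    # best score could still beat the worst finished hypothesis at this length
    return hyps.is_done(best_sum_logprobs=-4.0, cur_len=4)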
class BeamSearchScorer(BeamScorer):
def __init__(
self,
batch_size: int,
num_beams: int,
length_penalty: Optional[float] = 1.0,
do_early_stopping: Optional[bool] = False,
num_beam_hyps_to_keep: Optional[int] = 1,
num_beam_groups: Optional[int] = 1,
**kwargs,
):
self.num_beams = num_beams
self.length_penalty = length_penalty
self.do_early_stopping = do_early_stopping
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
self.num_beam_groups = num_beam_groups
self.group_size = self.num_beams // self.num_beam_groups
self._is_init = False
self._beam_hyps = [
BeamHypotheses(
num_beams=self.num_beams,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
)
for _ in range(batch_size)
]
self._done = flow.tensor(
[False for _ in range(batch_size)],
dtype=flow.bool,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
if not isinstance(num_beams, int) or num_beams <= 1:
raise ValueError(
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}."
"For `num_beams` == 1, one should make use of `greedy_search` instead."
)
if (
not isinstance(num_beam_groups, int)
or (num_beam_groups > num_beams)
or (num_beams % num_beam_groups != 0)
):
raise ValueError(
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and "
f"`num_beams` has to be divisible by `num_beam_groups`, but is {num_beam_groups}"
f"with `num_beams` being {num_beams}."
)
if "max_length" in kwargs:
warnings.warn(
"Passing `max_length` to BeamSearchScorer is deprecated and has no effect. "
"`max_length` should be passed directly to `beam_search(...)`, `beam_sample(...)`"
", or `group_beam_search(...)`."
)
@property
def is_done(self) -> bool:
return self._done.all()
def process(
self,
input_ids: flow.Tensor,
next_scores: flow.Tensor,
next_tokens: flow.Tensor,
next_indices: flow.Tensor,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[flow.Tensor] = None,
) -> Tuple[flow.Tensor]:
cur_len = input_ids.shape[-1]
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group "
f"beam size of {self.group_size} is expected by the beam scorer."
)
else:
raise ValueError(
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
f"{self.group_size} is expected by the beam scorer."
)
next_beam_scores = flow.zeros(
(batch_size, self.group_size),
dtype=next_scores.dtype,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
next_beam_tokens = flow.zeros(
(batch_size, self.group_size),
dtype=next_tokens.dtype,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
next_beam_indices = flow.zeros(
(batch_size, self.group_size),
dtype=next_indices.dtype,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
raise ValueError(
f"Batch can only be done if at least {self.num_beams} beams have "
"been generated"
)
if eos_token_id is None or pad_token_id is None:
raise ValueError(
"Generated beams >= num_beams -> eos_token_id and pad_token have "
"to be defined"
)
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (next_token.item() == eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,)
else:
beam_index = None
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
)
else:
# add next predicted token since it is not eos_token
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.group_size:
break
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal "
f"to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} "
"are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
def finalize(
self,
input_ids: flow.Tensor,
final_beam_scores: flow.Tensor,
final_beam_tokens: flow.Tensor,
final_beam_indices: flow.Tensor,
max_length: int,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
beam_indices: Optional[flow.Tensor] = None,
):
batch_size = len(self._beam_hyps)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_idx]:
continue
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
for beam_id in range(self.num_beams):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)
# select the best hypotheses
sent_lengths = flow.zeros(
batch_size * self.num_beam_hyps_to_keep,
dtype=flow.long,
sbp=input_ids.sbp,
placement=input_ids.placement,
)
best = []
best_indices = []
best_scores = flow.zeros(
batch_size * self.num_beam_hyps_to_keep,
dtype=flow.float32,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
# retrieve best hypotheses
for i, beam_hyp in enumerate(self._beam_hyps):
sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
best_index = best_hyp_tuple[2]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
# append hyp to lists
best.append(best_hyp)
# append indices to list
best_indices.append(best_index)
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = (
min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
)
decoded = flow.zeros(
(batch_size * self.num_beam_hyps_to_keep, sent_max_len),
dtype=flow.long,
sbp=input_ids.sbp,
placement=input_ids.placement,
)
if len(best_indices) > 0 and best_indices[0] is not None:
indices = flow.zeros(
(batch_size * self.num_beam_hyps_to_keep, sent_max_len),
dtype=flow.long,
sbp=input_ids.sbp,
placement=input_ids.placement,
)
else:
indices = None
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
assert pad_token_id is not None, "`pad_token_id` has to be defined"
decoded.fill_(pad_token_id)
if indices is not None:
indices.fill_(-1)
# fill with hypotheses and eos_token_id if the latter fits in
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
decoded[i, : sent_lengths[i]] = hypo
if indices is not None:
indices[i, : len(best_idx)] = flow.tensor(best_idx)
if sent_lengths[i] < sent_max_len:
decoded[i, sent_lengths[i]] = eos_token_id
return UserDict(
{
"sequences": decoded,
"sequence_scores": best_scores,
"beam_indices": indices,
}
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Callable, List, Tuple
import oneflow as flow
class LogitsProcessorList(list):
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor, **kwargs) -> flow.Tensor:
for processor in self:
function_args = inspect.signature(processor.__call__).parameters
if len(function_args) > 2:
if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
raise ValueError(
f"Make sure that all the required parameters: {list(function_args.keys())} "
"for {processor.__class__} are passed to the logits processor."
)
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)
return scores
class NormalizationLogitsProcessor(object):
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
scores = scores.log_softmax(dim=-1)
return scores
class InfNanRemoveLogitsProcessor(object):
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
scores[scores != scores] = 0.0
scores[scores == float("inf")] = flow.finfo(scores.dtype).max
return scores
class ForcedEOSTokenLogitsProcessor(object):
def __init__(self, max_length: int, eos_token_id: int):
self.max_length = max_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
cur_len = input_ids.shape[-1]
if cur_len == self.max_length - 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.eos_token_id]] = -float("inf")
scores[:, self.eos_token_id] = 0
return scores
class ForcedBOSTokenLogitsProcessor(object):
def __init__(self, bos_token_id: int):
self.bos_token_id = bos_token_id
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
cur_len = input_ids.shape[-1]
if cur_len == 1:
num_tokens = scores.shape[1]
scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
scores[:, self.bos_token_id] = 0
return scores
class RepetitionPenaltyLogitsProcessor(object):
def __init__(self, penalty: float):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
self.penalty = penalty
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
score = flow.gather(scores, 1, input_ids)
score = flow.where(score < 0, score * self.penalty, score / self.penalty)
scores = flow.scatter(scores, 1, input_ids, score)
return scores
class HammingDiversityLogitsProcessor(object):
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
self._diversity_penalty = diversity_penalty
if not isinstance(num_beams, int) or num_beams < 2:
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
self._num_beams = num_beams
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
if num_beam_groups > num_beams:
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
self._num_sub_beams = num_beams // num_beam_groups
def __call__(self, input_ids, scores, current_tokens, beam_group_idx) -> flow.Tensor:
scores = scores.numpy()
batch_size = current_tokens.shape[0] // self._num_beams
group_start_idx = beam_group_idx * self._num_sub_beams
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
group_size = group_end_idx - group_start_idx
vocab_size = scores.shape[-1]
if group_start_idx == 0:
return scores
for batch_idx in range(batch_size):
# predicted tokens of last time step of previous groups
previous_group_tokens = current_tokens[
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
]
token_frequency = flow.bincount(previous_group_tokens, minlength=vocab_size)
scores[batch_idx * group_size : (batch_idx + 1) * group_size] = (
scores[batch_idx * group_size : (batch_idx + 1) * group_size]
- self._diversity_penalty * token_frequency
)
return scores
def _get_ngrams(ngram_size: int, prev_input_ids: flow.Tensor, num_hypos: int):
generated_ngrams = [{} for _ in range(num_hypos)]
for idx in range(num_hypos):
gen_tokens = prev_input_ids[idx].tolist()
generated_ngram = generated_ngrams[idx]
for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
prev_ngram_tuple = tuple(ngram[:-1])
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [
ngram[-1]
]
return generated_ngrams
def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
start_idx = cur_len + 1 - ngram_size
ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
return banned_ngrams.get(ngram_idx, [])
def _calc_banned_ngram_tokens(
ngram_size: int, prev_input_ids: flow.Tensor, num_hypos: int, cur_len: int
):
if cur_len + 1 < ngram_size:
return [[] for _ in range(num_hypos)]
generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
banned_tokens = [
_get_generated_ngrams(
generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len
)
for hypo_idx in range(num_hypos)
]
return banned_tokens
class NoRepeatNGramLogitsProcessor(object):
def __init__(self, ngram_size: int):
if not isinstance(ngram_size, int) or ngram_size <= 0:
raise ValueError(
f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}"
)
self.ngram_size = ngram_size
def __call__(self, input_ids, scores) -> flow.Tensor:
num_batch_hypotheses = scores.shape[0]
cur_len = input_ids.shape[-1]
banned_batch_tokens = _calc_banned_ngram_tokens(
self.ngram_size, input_ids, num_batch_hypotheses, cur_len
)
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
class EncoderNoRepeatNGramLogitsProcessor(object):
def __init__(self, encoder_ngram_size: int, encoder_input_ids: flow.Tensor):
if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
raise ValueError(
"`encoder_ngram_size` has to be a strictly positive integer, but is "
f"{encoder_ngram_size}"
)
self.ngram_size = encoder_ngram_size
if len(encoder_input_ids.shape) == 1:
encoder_input_ids = encoder_input_ids.unsqueeze(0)
self.batch_size = encoder_input_ids.shape[0]
self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
# B x num_beams
num_hypos = scores.shape[0]
num_beams = num_hypos // self.batch_size
cur_len = input_ids.shape[-1]
banned_batch_tokens = [
_get_generated_ngrams(
self.generated_ngrams[hypo_idx // num_beams],
input_ids[hypo_idx],
self.ngram_size,
cur_len,
)
for hypo_idx in range(num_hypos)
]
for i, banned_tokens in enumerate(banned_batch_tokens):
scores[i, banned_tokens] = -float("inf")
return scores
class MinLengthLogitsProcessor(object):
def __init__(self, min_length: int, eos_token_id: int):
if not isinstance(min_length, int) or min_length < 0:
raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
if not isinstance(eos_token_id, int) or eos_token_id < 0:
raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
self.min_length = min_length
self.eos_token_id = eos_token_id
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
cur_len = input_ids.shape[-1]
if cur_len < self.min_length:
scores[:, self.eos_token_id] = -float("inf")
return scores
class PrefixConstrainedLogitsProcessor(object):
def __init__(
self, prefix_allowed_tokens_fn: Callable[[int, flow.Tensor], List[int]], num_beams: int
):
self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
self._num_beams = num_beams
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
mask = flow.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(
input_ids.view(-1, self._num_beams, input_ids.shape[-1])
):
for beam_id, sent in enumerate(beam_sent):
mask[
batch_id * self._num_beams + beam_id,
self._prefix_allowed_tokens_fn(batch_id, sent),
] = 0
return scores + mask
class ExponentialDecayLengthPenalty(object):
def __init__(
self, exponential_decay_length_penalty: Tuple, eos_token_id: int, input_ids_seq_length: int
):
self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
self.regulation_factor = exponential_decay_length_penalty[1]
self.eos_token_id = eos_token_id
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
cur_len = input_ids.shape[-1]
if cur_len > self.regulation_start:
scores[:, self.eos_token_id] = scores[:, self.eos_token_id] * pow(
self.regulation_factor, cur_len - self.regulation_start
)
return scores
class TemperatureLogitsWarper(object):
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
raise ValueError(
f"`temperature` has to be a strictly positive float, but is {temperature}"
)
self.temperature = temperature
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
scores = scores / self.temperature
return scores
class TopPLogitsWarper(object):
def __init__(
self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1
):
top_p = float(top_p)
if top_p < 0 or top_p > 1.0:
raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
self.top_p = top_p
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
sorted_logits, sorted_indices = flow.sort(scores, descending=True)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Remove tokens with cumulative top_p above the threshold (tokens with 0 probability are kept)
sorted_indices_to_remove = cumulative_probs > self.top_p
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1
# because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep - 1] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = flow.scatter(
sorted_indices_to_remove, 1, sorted_indices, sorted_indices_to_remove
)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TopKLogitsWarper(object):
def __init__(
self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1
):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
self.top_k = top_k
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
top_k = min(max(self.top_k, self.min_tokens_to_keep), scores.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = scores < flow.topk(scores, top_k)[0][..., -1, None]
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TypicalLogitsWarper(object):
def __init__(
self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1
):
mass = float(mass)
if not (mass > 0 and mass < 1):
raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
self.filter_value = filter_value
self.mass = mass
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> flow.Tensor:
# calculate entropy
normalized = flow.nn.functional.log_softmax(scores, dim=-1)
p = flow.exp(normalized)
ent = -flow.nansum(normalized * p, dim=-1, keepdim=True)
# shift and sort
shifted_scores = flow.abs((-normalized) - ent)
sorted_scores, sorted_indices = flow.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < self.mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if self.min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep
# (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
indices_to_remove = flow.scatter(
sorted_indices_to_remove, 1, sorted_indices, sorted_indices_to_remove
)
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
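# Minimal sketch (illustrative only, not part of the original file): chaining a
# few processors/warpers on a local logits tensor. Real generation code applies
# them to global tensors inside the sampling loop; the token ids, vocabulary
# size and hyperparameters here are made up.
def _logits_processing_demo():
    input_ids = flow.tensor([[0, 2, 2]])           # batch of one partial sequence
    scores = flow.tensor([[2.0, 1.0, 0.5, -1.0]])  # vocabulary of four tokens
    processors = LogitsProcessorList(
        [
            NoRepeatNGramLogitsProcessor(ngram_size=2),
            MinLengthLogitsProcessor(min_length=5, eos_token_id=3),
        ]
    )
    warpers = LogitsProcessorList(
        [
            TemperatureLogitsWarper(temperature=0.7),
            TopKLogitsWarper(top_k=2),
            TopPLogitsWarper(top_p=0.9),
        ]
    )
    scores = processors(input_ids, scores)  # bans repeated bigrams and early EOS
    scores = warpers(input_ids, scores)     # reshapes the distribution for sampling
    return scores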
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import warnings
from copy import deepcopy
import oneflow as flow
class StoppingCriteriaList(list):
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor, **kwargs) -> bool:
return any(criteria(input_ids, scores) for criteria in self)
@property
def max_length(self):
for stopping_criterium in self:
if isinstance(stopping_criterium, MaxLengthCriteria):
return stopping_criterium.max_length
return None
class MaxLengthCriteria(object):
def __init__(self, max_length: int):
self.max_length = max_length
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor) -> bool:
return input_ids.shape[-1] >= self.max_length
class MaxTimeCriteria(object):
def __init__(self, max_time: float, initial_timestamp: float = None):
self.max_time = max_time
self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
def __call__(self, input_ids: flow.Tensor, scores: flow.Tensor, **kwargs) -> bool:
return time.time() - self.initial_timestamp > self.max_time
def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
stopping_max_length = stopping_criteria.max_length
new_stopping_criteria = deepcopy(stopping_criteria)
if stopping_max_length is not None and stopping_max_length != max_length:
warnings.warn(
"You set different `max_length` for stopping criteria and `max_length` parameter",
UserWarning,
)
elif stopping_max_length is None:
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
return new_stopping_criteria
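# Minimal sketch (illustrative only, not part of the original file): combining
# stopping criteria the way a generation loop consumes them. The token ids and
# scores are made up.
def _stopping_criteria_demo():
    criteria = StoppingCriteriaList(
        [MaxLengthCriteria(max_length=8), MaxTimeCriteria(max_time=30.0)]
    )
    criteria = validate_stopping_criteria(criteria, max_length=8)
    input_ids = flow.tensor([[0, 1, 2, 3]])
    scores = flow.tensor([[0.1, 0.9]])
    return criteria(input_ids, scores)  # False until length or time is exceeded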
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2020 The Google AI Language Team Authors, Facebook AI Research authors and
# The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
import oneflow as flow
from oneflow import nn
from libai.utils import distributed as dist
from .generation_beam_search import BeamScorer, BeamSearchScorer
from .generation_logits_processor import (
EncoderNoRepeatNGramLogitsProcessor,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor,
HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor,
LogitsProcessorList,
MinLengthLogitsProcessor,
NoRepeatNGramLogitsProcessor,
NormalizationLogitsProcessor,
PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
TypicalLogitsWarper,
)
from .generation_stopping_criteria import (
MaxLengthCriteria,
MaxTimeCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
logger = logging.getLogger(__name__)
class Generator:
def _prepare_model_inputs(
self,
inputs: Optional[flow.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, flow.Tensor]] = None,
):
if self.cfg.is_encoder_decoder:
input_name = "encoder_input_ids"
else:
input_name = "input_ids"
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside "
f"{input_name} which is not allowed."
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
if inputs is None:
inputs = self._prepare_input_ids_for_generation(
bos_token_id, model_kwargs.get("encoder_outputs", None)
)
return inputs, input_name, model_kwargs
def prepare_inputs_for_generation(self, input_ids: flow.Tensor, **kwargs):
"""
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the
generate method.
"""
return {"input_ids": input_ids}
def _prepare_input_ids_for_generation(
self, bos_token_id: Optional[int], encoder_outputs: Optional[flow.Tensor]
):
if self.cfg.is_encoder_decoder and encoder_outputs is not None:
shape = encoder_outputs.size()[:-1]
return (
flow.ones(
shape,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
* -100
)
if bos_token_id is None:
raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
return (
flow.ones(
(1, 1),
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
* bos_token_id
)
def _prepare_attention_mask_for_generation(
self,
inputs: flow.Tensor,
pad_token_id: Optional[int],
eos_token_id: Optional[int],
):
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [flow.int64, flow.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
# Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).bool()
else:
return flow.ones(
inputs.shape[:2],
dtype=flow.bool,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: flow.Tensor, model_kwargs, model_input_name: str
):
only_encoder = True
model_kwargs[model_input_name] = inputs_tensor
if "encoder_decoder_attn_mask" in set(inspect.signature(self.forward).parameters):
model_kwargs["encoder_decoder_attn_mask"] = model_kwargs["encoder_attn_mask"]
model_kwargs["encoder_outputs"] = self(**model_kwargs, only_encoder=only_encoder)
model_kwargs.pop(model_input_name)
return model_kwargs
def _prepare_decoder_input_ids_for_generation(
self,
batch_size: int,
decoder_start_token_id: int = None,
bos_token_id: int = None,
model_kwargs=None,
):
if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
return model_kwargs.pop("decoder_input_ids")
else:
            decoder_start_token_id = (
                decoder_start_token_id
                if decoder_start_token_id is not None
                else self.cfg.decoder_start_token_id
            )
return (
flow.ones(
(batch_size, 1),
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
* decoder_start_token_id
)
def _get_decoder_start_token_id(
self, decoder_start_token_id: int = None, bos_token_id: int = None
):
if decoder_start_token_id is not None:
return decoder_start_token_id
elif self.cfg.is_encoder_decoder:
return self.cfg.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
else:
            return self.cfg.bos_token_id
@staticmethod
def _expand_inputs_for_generation(
input_ids: flow.Tensor,
expand_size: int = 1,
is_encoder_decoder: bool = False,
attention_mask: Optional[flow.Tensor] = None,
encoder_outputs: Optional[flow.Tensor] = None,
**model_kwargs,
):
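        # Repeat every sequence (and its masks / encoder outputs) `expand_size` times along
        # the batch dimension, e.g. to interleave beams or draw several return sequences.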
expanded_return_idx = (
flow.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1)
)
expanded_return_idx = expanded_return_idx.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
input_ids = input_ids.index_select(0, expanded_return_idx)
# token_type ids not supported.
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
if is_encoder_decoder:
if encoder_outputs is None:
raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs = encoder_outputs.to_global(placement=expanded_return_idx.placement)
encoder_outputs = encoder_outputs.index_select(0, expanded_return_idx)
model_kwargs["encoder_outputs"] = encoder_outputs
model_kwargs["encoder_attn_mask"] = model_kwargs["encoder_attn_mask"].index_select(
0, expanded_return_idx
)
model_kwargs["encoder_decoder_attn_mask"] = model_kwargs["encoder_attn_mask"]
return input_ids, model_kwargs
def _update_model_kwargs_for_generation(
self, outputs, model_kwargs, is_encoder_decoder: bool = False
):
if "past_key_values" in outputs:
model_kwargs["past"] = outputs["past_key_values"]
elif "mems" in outputs:
model_kwargs["past"] = outputs["mems"]
elif "past_buckets_states" in outputs:
model_kwargs["past"] = outputs["past_buckets_states"]
elif self.past_key_values[-1] is not None:
model_kwargs["past"] = self.past_key_values
else:
model_kwargs["past"] = None
# update attention mask
if "attention_mask" in model_kwargs and not is_encoder_decoder:
attention_mask = model_kwargs["attention_mask"]
pad = flow.ones(
(attention_mask.shape[0], 1),
sbp=attention_mask.sbp,
placement=attention_mask.placement,
)
model_kwargs["attention_mask"] = flow.cat([attention_mask, pad], dim=-1)
if "decoder_attn_mask" in model_kwargs and is_encoder_decoder:
attention_mask = model_kwargs["decoder_attn_mask"]
pad = flow.ones(
(attention_mask.shape[0], 1),
sbp=attention_mask.sbp,
placement=attention_mask.placement,
)
model_kwargs["decoder_attn_mask"] = flow.cat([attention_mask, pad], dim=-1)
return model_kwargs
def _reorder_cache(self, past, beam_idx):
raise NotImplementedError(
"Make sure that a `_reorder_cache` function is correctly implemented in "
f"{self.__class__.__module__} to enable beam search for {self.__class__}"
)
def _get_logits_warper(
self,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
temperature: Optional[float] = None,
num_beams: Optional[int] = None,
renormalize_logits: Optional[bool] = None,
):
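        """
        Build a [`LogitsProcessorList`] of distribution warpers used when sampling:
        temperature scaling, top-k, top-p (nucleus) and typical-p filtering.
        """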
# instantiate warpers list
warpers = LogitsProcessorList()
# all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(TemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(
TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1))
)
if top_p is not None and top_p < 1.0:
warpers.append(
TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))
)
if typical_p is not None and typical_p < 1.0:
warpers.append(
TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1))
)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
warpers.append(NormalizationLogitsProcessor())
return warpers
def _get_logits_processor(
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
encoder_no_repeat_ngram_size: int,
input_ids_seq_length: int,
encoder_input_ids: flow.Tensor,
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
prefix_allowed_tokens_fn: Callable[[int, flow.Tensor], List[int]],
num_beams: int,
num_beam_groups: int,
diversity_penalty: float,
remove_invalid_values: bool,
exponential_decay_length_penalty: Tuple,
logits_processor: Optional[LogitsProcessorList],
renormalize_logits: Optional[bool],
):
"""
This class returns a [`LogitsProcessorList`] list object that contains all relevant
[`LogitsProcessor`] instances used to modify the scores of the language model head.
"""
processors = LogitsProcessorList()
# instantiate processors list
if diversity_penalty is not None and diversity_penalty > 0.0:
processors.append(
HammingDiversityLogitsProcessor(
diversity_penalty=diversity_penalty,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
)
)
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
if self.cfg.is_encoder_decoder:
processors.append(
EncoderNoRepeatNGramLogitsProcessor(
encoder_no_repeat_ngram_size, encoder_input_ids
)
)
else:
raise ValueError(
"It's impossible to use `encoder_no_repeat_ngram_size` with decoder-only "
"architecture"
)
if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(MinLengthLogitsProcessor(min_length, eos_token_id))
if prefix_allowed_tokens_fn is not None:
processors.append(
PrefixConstrainedLogitsProcessor(
prefix_allowed_tokens_fn, num_beams // num_beam_groups
)
)
if forced_bos_token_id is not None:
processors.append(ForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(ForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if remove_invalid_values is True:
processors.append(InfNanRemoveLogitsProcessor())
if exponential_decay_length_penalty is not None:
processors.append(
ExponentialDecayLengthPenalty(
exponential_decay_length_penalty, eos_token_id, input_ids_seq_length
)
)
processors = self._merge_criteria_processor_list(processors, logits_processor)
# `LogitNormalization` should always be the last logit processor, when present
if renormalize_logits is True:
processors.append(NormalizationLogitsProcessor())
return processors
def _get_stopping_criteria(
self,
max_length: Optional[int],
max_time: Optional[float],
stopping_criteria: Optional[StoppingCriteriaList],
):
criteria = StoppingCriteriaList()
if max_length is not None:
criteria.append(MaxLengthCriteria(max_length=max_length))
if max_time is not None:
criteria.append(MaxTimeCriteria(max_time=max_time))
criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
return criteria
def _merge_criteria_processor_list(self, default_list, custom_list):
if len(custom_list) == 0:
return default_list
for default in default_list:
for custom in custom_list:
if type(custom) is type(default):
raise ValueError("Criteria repetition error.")
default_list.extend(custom_list)
return default_list
def compute_transition_beam_scores(
self,
sequences: flow.Tensor,
scores: Tuple[flow.Tensor],
beam_indices: flow.Tensor,
eos_token_id: int = None,
):
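        """
        Recover the per-step transition scores of the finally selected beams from the stored
        `scores`, using `beam_indices` to map each generated token back to its source beam.
        """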
scores = flow.stack(scores).reshape(len(scores), -1).transpose(0, 1)
beam_indices_mask = beam_indices < 0
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
beam_indices = beam_indices[:, :max_beam_length]
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
beam_indices[beam_indices_mask] = 0
beam_sequence_indices = beam_indices * self.cfg.vocab_size
cut_idx = sequences.shape[-1] - max_beam_length
indices = sequences[:, cut_idx:] + beam_sequence_indices
transition_scores = scores.gather(0, indices)
transition_scores[beam_indices_mask] = 0
return transition_scores
def _validate_model_kwargs(self, model_kwargs):
if self.cfg.is_encoder_decoder:
for key in ["decoder_input_ids"]:
model_kwargs.pop(key, None)
unused_model_args = []
model_args = set(inspect.signature(self.prepare_inputs_for_generation).parameters)
if "kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters)
for key, value in model_kwargs.items():
if value is not None and key not in model_args:
unused_model_args.append(key)
if unused_model_args:
raise ValueError(
f"The following `model_kwargs` are not used by the model: {unused_model_args} "
"(note: typos in the generate arguments will also show up in this list)"
)
def greedy_search(
self,
input_ids: flow.Tensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
is_encoder_decoder: bool = False,
output_scores: bool = False,
**model_kwargs,
):
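        """
        Greedy decoding: at every step pick the highest-scoring token, append it to
        `input_ids`, and stop once all sequences hit `eos_token_id` or a stopping criterion.
        """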
pad_token_id = pad_token_id if pad_token_id is not None else self.cfg.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.cfg.eos_token_id
output_scores = output_scores if output_scores is not None else self.cfg.output_scores
scores = () if output_scores else None
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
)
if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use MaxLengthCriteria instead.",
                UserWarning,
            )
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
# keep track of which sequences are already finished
unfinished_sequences = flow.ones(input_ids.shape[0])
cur_len = input_ids.shape[-1]
while True:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# generate
outputs = self(**model_inputs)
next_token_logits = outputs["logits"][:, -1, :]
# logits_processor
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores
if output_scores:
scores += (next_token_scores,)
# argmax
next_tokens = flow.argmax(next_token_scores, dim=-1)
next_tokens = next_tokens.to_global(placement=input_ids.placement)
unfinished_sequences = unfinished_sequences.to_global(
sbp=next_tokens.sbp, placement=next_tokens.placement
)
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
next_tokens = next_tokens.to(flow.long)
input_ids = flow.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
)
cur_len = cur_len + 1
# if eos_token was found in one sentence, set sentence to finished
if eos_token_id is not None:
unfinished_sequences = flow.mul(
unfinished_sequences, (next_tokens != eos_token_id).long()
)
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
# Release records
if "past_key_values" in self.__dir__():
self.past_key_values = [None] * self.cfg.hidden_layers
if "encoder_states" in self.__dir__():
self.encoder_states = None
return input_ids
def multinomial_sample(
self,
input_ids: flow.Tensor,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
is_encoder_decoder: bool = False,
output_scores: bool = False,
**model_kwargs,
):
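        """
        Multinomial sampling: at every step sample the next token from the processed and
        warped softmax distribution, until `eos_token_id` or a stopping criterion is hit.
        """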
# init values
pad_token_id = pad_token_id if pad_token_id is not None else self.cfg.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.cfg.eos_token_id
output_scores = output_scores if output_scores is not None else self.cfg.output_scores
scores = () if output_scores else None
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
)
if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use "
                "`stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
                " instead.",
                UserWarning,
            )
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
unfinished_sequences = flow.ones(input_ids.shape[0])
cur_len = input_ids.shape[-1]
while True:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
# generate
outputs = self(**model_inputs)
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores
if output_scores:
scores += (next_token_scores,)
# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
probs = probs.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
).to_local()
next_tokens = flow.multinomial(probs, num_samples=1).squeeze(1)
next_tokens = next_tokens.to_global(
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
unfinished_sequences = unfinished_sequences.to_global(
sbp=next_tokens.sbp, placement=next_tokens.placement
)
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError(
"If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
)
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
1 - unfinished_sequences
)
next_tokens = next_tokens.to(flow.long)
input_ids = flow.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
)
cur_len = cur_len + 1
if eos_token_id is not None:
unfinished_sequences = flow.mul(
unfinished_sequences, (next_tokens != eos_token_id).long()
)
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
break
# Release records
if "past_key_values" in self.__dir__():
self.past_key_values = [None] * self.cfg.hidden_layers
if "encoder_states" in self.__dir__():
self.encoder_states = None
return input_ids
def beam_search(
self,
input_ids: flow.Tensor,
beam_scorer: BeamScorer,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
max_length: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
is_encoder_decoder: bool = False,
output_scores: bool = False,
**model_kwargs,
):
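        """
        Beam search decoding: keep `num_beams` hypotheses per sample, expand each step with
        the top `2 * num_beams` candidates, and let `beam_scorer` select, reorder and
        finalize the best sequences.
        """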
pad_token_id = pad_token_id if pad_token_id is not None else self.cfg.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.cfg.eos_token_id
output_scores = output_scores if output_scores is not None else self.cfg.output_scores
scores = () if output_scores else None
logits_processor = (
logits_processor if logits_processor is not None else LogitsProcessorList()
)
stopping_criteria = (
stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
)
if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use "
                "`stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))`"
                " instead.",
                UserWarning,
            )
stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
if len(stopping_criteria) == 0:
            warnings.warn(
                "You have not defined any stopping_criteria; this will likely loop forever.",
                UserWarning,
            )
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, "
f"but is {batch_beam_size}."
)
beam_indices = None
beam_scores = flow.zeros(
(batch_size, num_beams),
dtype=flow.float,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.placement("cuda", list(range(dist.get_world_size()))),
)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
while True:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
outputs = self(**model_inputs)
next_token_logits = outputs["logits"][:, -1, :]
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores = next_token_scores.to_global(
sbp=input_ids.sbp, placement=input_ids.placement
)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores
)
# Store scores
if output_scores:
scores += (next_token_scores,)
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
next_token_scores, next_tokens = flow.topk(
next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
)
next_indices = next_tokens // vocab_size
next_tokens = next_tokens % vocab_size
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = flow.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
)
# update past_key_value
if model_kwargs["past"] is not None:
                model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
break
sequence_outputs = beam_scorer.finalize(
input_ids,
beam_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=stopping_criteria.max_length,
beam_indices=beam_indices,
)
# Release records
if "past_key_values" in self.__dir__():
self.past_key_values = [None] * self.cfg.hidden_layers
if "encoder_states" in self.__dir__():
self.encoder_states = None
return sequence_outputs["sequences"]
@flow.no_grad()
def generate(
self,
inputs: Optional[flow.Tensor] = None,
max_length: Optional[int] = None,
min_length: Optional[int] = None,
do_sample: Optional[bool] = None,
early_stopping: Optional[bool] = None,
num_beams: Optional[int] = None,
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
typical_p: Optional[float] = None,
repetition_penalty: Optional[float] = None,
force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, flow.Tensor], List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
renormalize_logits: Optional[bool] = None,
stopping_criteria=StoppingCriteriaList(),
constraints=None,
output_scores: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
**model_kwargs,
):
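        """
        Generation entry point. Prepares model inputs, attention masks, logits processors
        and stopping criteria, then dispatches to `greedy_search`, `multinomial_sample` or
        `beam_search` depending on `num_beams`, `num_beam_groups` and `do_sample`.
        """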
# 0. Validate model kwargs
self._validate_model_kwargs(model_kwargs.copy())
# 1. Set generation parameters if not already defined
bos_token_id = bos_token_id if bos_token_id is not None else self.cfg.bos_token_id
num_beams = num_beams if num_beams is not None else self.cfg.num_beams
length_penalty = length_penalty if length_penalty is not None else self.cfg.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.cfg.early_stopping
num_beam_groups = (
num_beam_groups if num_beam_groups is not None else self.cfg.num_beam_groups
)
do_sample = do_sample if do_sample is not None else self.cfg.do_sample
num_return_sequences = (
num_return_sequences
if num_return_sequences is not None
else self.cfg.num_return_sequences
)
pad_token_id = pad_token_id if pad_token_id is not None else self.cfg.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.cfg.eos_token_id
output_scores = output_scores if output_scores is not None else self.cfg.output_scores
# 2. Prepare model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
# 3. Prepare other model kwargs
model_kwargs["use_cache"] = use_cache if use_cache is not None else self.cfg.use_cache
        if self.cfg.is_encoder_decoder:
            att_mask_name = "encoder_attn_mask"
        else:
            att_mask_name = "attention_mask"
        accepts_attention_mask = att_mask_name in set(
            inspect.signature(self.forward).parameters.keys()
        )
requires_attention_mask = "encoder_outputs" not in model_kwargs
if (
model_kwargs.get(att_mask_name, None) is None
and requires_attention_mask
and accepts_attention_mask
):
model_kwargs[att_mask_name] = self._prepare_attention_mask_for_generation(
inputs_tensor, pad_token_id, eos_token_id
)
if self.cfg.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created
# and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name
)
# 4. Prepare `input_ids` which will be used for auto-regressive generation
if self.cfg.is_encoder_decoder:
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
model_kwargs=model_kwargs,
)
else:
# if decoder-only then inputs_tensor has to be `input_ids`
input_ids = inputs_tensor
# 5. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
if dist.is_main_process():
                warnings.warn(
                    "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will "
                    f"default to {self.cfg.max_length} (`self.cfg.max_length`). We recommend using"
                    " `max_new_tokens` to control the maximum length of the generation.",
                    UserWarning,
                )
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
            raise ValueError(
                "Both `max_new_tokens` and `max_length` have been set, but they serve the same "
                "purpose. Please set only one of them."
            )
# default to cfg if still None
max_length = max_length if max_length is not None else self.cfg.max_length
min_length = min_length if min_length is not None else self.cfg.min_length
if min_length is not None and min_length > max_length:
            raise ValueError(
                f"Unfeasible length constraints: the minimum length ({min_length}) is larger "
                f"than the maximum length ({max_length})."
            )
if input_ids_seq_length >= max_length:
input_ids_string = "decoder_input_ids" if self.cfg.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is"
f" set to {max_length}. This can lead to unexpected behavior. You should consider "
"increasing `max_new_tokens`."
)
# 6. Determine generation mode
is_constraint_gen_mode = constraints is not None or force_words_ids is not None
is_greedy_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
)
is_sample_gen_mode = (
(num_beams == 1)
and (num_beam_groups == 1)
and do_sample is True
and not is_constraint_gen_mode
)
is_beam_gen_mode = (
(num_beams > 1)
and (num_beam_groups == 1)
and do_sample is False
and not is_constraint_gen_mode
)
# is_beam_sample_gen_mode = (
# (num_beams > 1)
# and (num_beam_groups == 1)
# and do_sample is True
# and not is_constraint_gen_mode
# )
is_group_beam_gen_mode = (
(num_beams > 1) and (num_beam_groups > 1) and not is_constraint_gen_mode
)
if num_beam_groups > num_beams:
            raise ValueError("`num_beam_groups` has to be smaller than or equal to `num_beams`")
if is_group_beam_gen_mode and do_sample is True:
raise ValueError(
"Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is"
" set to `False`."
)
# 7. Prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=inputs_tensor,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
num_beams=num_beams,
num_beam_groups=num_beam_groups,
diversity_penalty=diversity_penalty,
remove_invalid_values=remove_invalid_values,
exponential_decay_length_penalty=exponential_decay_length_penalty,
logits_processor=logits_processor,
renormalize_logits=renormalize_logits,
)
# 8. Prepare stopping criteria
stopping_criteria = self._get_stopping_criteria(
max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
)
# 9. Go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing"
" greedy search."
)
# 10. Run greedy search
return self.greedy_search(
input_ids,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. Prepare logits warper
logits_warper = self._get_logits_warper(
top_k=top_k,
top_p=top_p,
typical_p=typical_p,
temperature=temperature,
num_beams=num_beams,
renormalize_logits=renormalize_logits,
)
# 11. Expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_return_sequences,
is_encoder_decoder=self.cfg.is_encoder_decoder,
**model_kwargs,
)
# 12. Run multinomial sample
return self.multinomial_sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
**model_kwargs,
)
elif is_beam_gen_mode:
if num_return_sequences > num_beams:
                raise ValueError(
                    "`num_return_sequences` has to be smaller than or equal to `num_beams`."
                )
if stopping_criteria.max_length is None:
raise ValueError("`max_length` needs to be a stopping_criteria for now.")
# 10. Prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=num_beams,
length_penalty=length_penalty,
do_early_stopping=early_stopping,
num_beam_hyps_to_keep=num_return_sequences,
)
# 11. Interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids,
expand_size=num_beams,
is_encoder_decoder=self.cfg.is_encoder_decoder,
**model_kwargs,
)
# 12. Run beam search
return self.beam_search(
input_ids,
beam_scorer,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
**model_kwargs,
)
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import oneflow as flow
from PIL import Image
from libai.config import instantiate
from libai.data.structures import DistTensorData, Instance
from libai.inference.basic import BasePipeline
class ImageClassificationPipeline(BasePipeline):
def __init__(
self,
config_file,
data_parallel=None,
tensor_parallel=None,
pipeline_parallel=None,
pipeline_stage_id=None,
pipeline_num_layers=None,
model_path=None,
mode="libai",
**kwargs,
):
super().__init__(
config_file,
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
pipeline_num_layers,
model_path,
mode,
**kwargs,
)
if "num_classes" in self.cfg.model:
self.num_classes = self.cfg.model.num_classes
elif "num_classes" in self.cfg.model.cfg:
self.num_classes = self.cfg.model.cfg.num_classes
else:
raise AttributeError("The model's config must contain num_classes")
label2id = self.label2id(self.num_classes)
self.id2label = {ind: label for label, ind in label2id.items()}
self.transform = instantiate(self.cfg.dataloader.test[0].dataset.transform)
def _parse_parameters(self, **pipeline_parameters):
preprocess_params = {}
forward_params = {}
postprocess_params = {**pipeline_parameters}
return preprocess_params, forward_params, postprocess_params
def preprocess(
self,
inputs,
**kwargs,
) -> dict:
assert os.path.exists(inputs), "inputs must be an existing image path!"
with open(inputs, "rb") as f:
img = Image.open(f).convert("RGB")
img = self.transform(img)
img = img.unsqueeze(0)
# to global tensor
model_input = Instance(
images=DistTensorData(img),
)
        model_input_dict = {}
        for key, value in model_input.get_fields().items():
            value.to_global()
            model_input_dict[key] = value.tensor
        return model_input_dict
    def forward(self, model_input_dict) -> dict:
        model_outputs_dict = self.model(**model_input_dict)
return model_outputs_dict
def postprocess(
self, model_outputs_dict, function_to_apply=None, return_all_scores=False, **kwargs
) -> dict:
# prepare
num_labels = self.num_classes
if function_to_apply is not None:
function_to_apply = function_to_apply.lower()
assert function_to_apply in [
"sigmoid",
"softmax",
"none",
], f"Unrecognized `function_to_apply` argument: {function_to_apply}"
else:
if num_labels == 1:
function_to_apply = "sigmoid"
elif num_labels > 1:
function_to_apply = "softmax"
# process, logits: [num_labels]
logits = model_outputs_dict["prediction_scores"][0]
if function_to_apply == "sigmoid":
scores = flow.sigmoid(logits)
elif function_to_apply == "softmax":
scores = flow.softmax(logits)
else:
scores = logits
scores = scores.detach().numpy()
if return_all_scores:
return [
{"label": self.id2label[i], "score": score.item()} for i, score in enumerate(scores)
]
else:
return {
"label": self.id2label[scores.argmax().item()],
"score": scores.max().item(),
}
def label2id(self, num_classes):
"""
Args:
num_classes (int): the number of total classes
Returns:
            labels (dict): a dict that contains all the labels for inference,
            each item takes the following form:
            {
                "tench": 0,
                "tiger": 1,
                "xxx": n,
}
"""
from libai.inference.utils.imagenet_class import IMAGENET_LABELS as labels
assert num_classes == len(labels), "number of labels must be equal to num_classes"
return {label: i for (i, label) in enumerate(labels)}
if __name__ == "__main__":
pipeline = ImageClassificationPipeline("/home/chengpeng/config.yaml", 1, 1, 1)
print(pipeline("data_test/inference_test_data/ILSVRC2012_val_00000293.JPEG"))
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import oneflow as flow
from libai.data.structures import DistTensorData, Instance
from libai.inference.basic import BasePipeline
class TextClassificationPipeline(BasePipeline):
def __init__(
self,
config_file,
data_parallel=None,
tensor_parallel=None,
pipeline_parallel=None,
pipeline_stage_id=None,
pipeline_num_layers=None,
model_path=None,
mode="libai",
**kwargs,
):
super().__init__(
config_file,
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
            pipeline_num_layers,
            model_path,
mode,
**kwargs,
)
def update_cfg(
self,
data_parallel=1,
tensor_parallel=1,
pipeline_parallel=1,
pipeline_stage_id=None,
pipeline_num_layers=None,
):
super().update_cfg(
data_parallel,
tensor_parallel,
pipeline_parallel,
pipeline_stage_id,
pipeline_num_layers,
)
self.cfg.model.cfg.hidden_dropout_prob = 0.0
self.cfg.model.cfg.attention_probs_dropout_prob = 0.0
assert "num_labels" in self.cfg.model.cfg, "The model's config must contain num_labels"
if "label2id" not in self.cfg.model.cfg:
label2id = {"Label_" + str(i): i for i in range(self.cfg.model.cfg.num_labels)}
id2label = {ind: label for label, ind in label2id.items()}
self.cfg.model.cfg["label2id"] = label2id
self.cfg.model.cfg["id2label"] = id2label
def _parse_parameters(self, **pipeline_parameters):
preprocess_params = {}
forward_params = {}
postprocess_params = {**pipeline_parameters}
return preprocess_params, forward_params, postprocess_params
def preprocess(
self,
inputs,
pad: bool = False,
**kwargs,
) -> dict:
# tokenizer encoder
input_ids = flow.tensor(np.array(self.tokenizer.encode(inputs)))
padding_mask = flow.tensor(np.ones(input_ids.shape), dtype=flow.bool)
# set batch size = 1
input_ids = input_ids.unsqueeze(0)
padding_mask = padding_mask.unsqueeze(0)
# to global tensor
model_input = Instance(
input_ids=DistTensorData(input_ids),
attention_mask=DistTensorData(padding_mask),
)
        model_input_dict = {}
        for key, value in model_input.get_fields().items():
            value.to_global()
            model_input_dict[key] = value.tensor
        return model_input_dict
    def forward(self, model_input_dict) -> dict:
        model_outputs_dict = self.model(**model_input_dict)
return model_outputs_dict
def postprocess(
self, model_outputs_dict, function_to_apply=None, return_all_scores=False, **kwargs
) -> dict:
# prepare
num_labels = self.cfg.model.cfg.num_labels
if function_to_apply is not None:
function_to_apply = function_to_apply.lower()
assert function_to_apply in [
"sigmoid",
"softmax",
"none",
], f"Unrecognized `function_to_apply` argument: {function_to_apply}"
else:
if num_labels == 1:
function_to_apply = "sigmoid"
elif num_labels > 1:
function_to_apply = "softmax"
# process, logits: [num_labels]
logits = model_outputs_dict["logits"][0]
if function_to_apply == "sigmoid":
scores = flow.sigmoid(logits)
elif function_to_apply == "softmax":
scores = flow.softmax(logits)
else:
scores = logits
scores = scores.detach().numpy()
if return_all_scores:
return [
{"label": self.cfg.model.cfg.id2label[i], "score": score.item()}
for i, score in enumerate(scores)
]
else:
return {
"label": self.cfg.model.cfg.id2label[scores.argmax().item()],
"score": scores.max().item(),
}
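# Minimal usage sketch (the config path and input text below are placeholders), mirroring
# the `__main__` blocks of the other pipelines in this file:
if __name__ == "__main__":
    pipeline = TextClassificationPipeline("/path/to/config.yaml", 1, 1, 1)
    print(pipeline("LiBai is an open-source toolbox for large-scale model training."))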
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from libai.inference.basic import BasePipeline
from libai.utils import distributed as dist
class TextGenerationPipeline(BasePipeline):
    def load_pretrain_weight(self, libai_cfg_model, model_path, mode="huggingface"):
        """Load a pretrained model.
        Args:
            libai_cfg_model (libai.models): lazy config model in LiBai, which you can import
                by `from libai.config.configs.common.models.bert
                import pretrain_model as libai_cfg_model`
            model_path (str): the directory path of the pretrained model.
            mode (str): one of "huggingface", "libai" or "random", selecting how the
                weights are loaded (see the branches below).
        """
if mode == "huggingface":
from projects.MT5.utils.mt5_loader import T5LoaderHuggerFace
model_loader = T5LoaderHuggerFace(
libai_cfg_model,
libai_cfg_model.cfg,
model_path,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
embedding_dropout_prob=0.0,
)
return model_loader.load()
elif mode == "libai":
from projects.MT5.utils.mt5_loader import T5LoaderLibai
model_loader = T5LoaderLibai(
libai_cfg_model,
libai_cfg_model.cfg,
model_path,
)
return model_loader.load()
elif mode == "random":
from libai.engine import DefaultTrainer
return DefaultTrainer.build_model(self.cfg)
else:
raise NotImplementedError
def _parse_parameters(self, **pipeline_parameters):
preprocess_params = {}
forward_params = {**pipeline_parameters}
postprocess_params = {}
return preprocess_params, forward_params, postprocess_params
def preprocess(
self,
inputs,
pad: bool = False,
**kwargs,
) -> dict:
# tokenizer encoder
encoder_ids = self.tokenizer.encode(inputs, return_tensors="of", is_global=True)
encoder_input_dict = {
"encoder_ids": encoder_ids,
}
return encoder_input_dict
def forward(self, encoder_input_dict, **kwargs) -> dict:
outputs = self.model.generate(encoder_input_dict["encoder_ids"], **kwargs)
return {"return_ids": outputs}
def postprocess(self, model_output_dict, **kwargs) -> dict:
return_ids = model_output_dict["return_ids"]
records = [
{"generated_text": self.tokenizer.decode(return_ids[i], skip_special_tokens=True)}
for i in range(return_ids.size(0))
]
return records
if __name__ == "__main__":
pipeline = TextGenerationPipeline(
"/path/to/libai/projects/MT5/configs/t5_inference.py",
data_parallel=1,
tensor_parallel=2,
pipeline_parallel=2,
pipeline_stage_id=[0] * 12 + [1] * 12,
pipeline_num_layers=12 * 2,
model_path="/path/to/t5-base",
mode="huggingface",
)
text = ["summarize: She is a student, She is tall, She loves study"]
dict1 = pipeline(text)
if dist.is_main_process():
print(dict1)
IMAGENET_LABELS = [
"tench, Tinca tinca",
"goldfish, Carassius auratus",
"great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias", # noqa: E501
"tiger shark, Galeocerdo cuvieri",
"hammerhead, hammerhead shark",
"electric ray, crampfish, numbfish, torpedo",
"stingray",
"cock",
"hen",
"ostrich, Struthio camelus",
"brambling, Fringilla montifringilla",
"goldfinch, Carduelis carduelis",
"house finch, linnet, Carpodacus mexicanus",
"junco, snowbird",
"indigo bunting, indigo finch, indigo bird, Passerina cyanea",
"robin, American robin, Turdus migratorius",
"bulbul",
"jay",
"magpie",
"chickadee",
"water ouzel, dipper",
"kite",
"bald eagle, American eagle, Haliaeetus leucocephalus",
"vulture",
"great grey owl, great gray owl, Strix nebulosa",
"European fire salamander, Salamandra salamandra",
"common newt, Triturus vulgaris",
"eft",
"spotted salamander, Ambystoma maculatum",
"axolotl, mud puppy, Ambystoma mexicanum",
"bullfrog, Rana catesbeiana",
"tree frog, tree-frog",
"tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
"loggerhead, loggerhead turtle, Caretta caretta",
"leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea", # noqa: E501
"mud turtle",
"terrapin",
"box turtle, box tortoise",
"banded gecko",
"common iguana, iguana, Iguana iguana",
"American chameleon, anole, Anolis carolinensis",
"whiptail, whiptail lizard",
"agama",
"frilled lizard, Chlamydosaurus kingi",
"alligator lizard",
"Gila monster, Heloderma suspectum",
"green lizard, Lacerta viridis",
"African chameleon, Chamaeleo chamaeleon",
"Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis", # noqa: E501
"African crocodile, Nile crocodile, Crocodylus niloticus",
"American alligator, Alligator mississipiensis",
"triceratops",
"thunder snake, worm snake, Carphophis amoenus",
"ringneck snake, ring-necked snake, ring snake",
"hognose snake, puff adder, sand viper",
"green snake, grass snake",
"king snake, kingsnake",
"garter snake, grass snake",
"water snake",
"vine snake",
"night snake, Hypsiglena torquata",
"boa constrictor, Constrictor constrictor",
"rock python, rock snake, Python sebae",
"Indian cobra, Naja naja",
"green mamba",
"sea snake",
"horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
"diamondback, diamondback rattlesnake, Crotalus adamanteus",
"sidewinder, horned rattlesnake, Crotalus cerastes",
"trilobite",
"harvestman, daddy longlegs, Phalangium opilio",
"scorpion",
"black and gold garden spider, Argiope aurantia",
"barn spider, Araneus cavaticus",
"garden spider, Aranea diademata",
"black widow, Latrodectus mactans",
"tarantula",
"wolf spider, hunting spider",
"tick",
"centipede",
"black grouse",
"ptarmigan",
"ruffed grouse, partridge, Bonasa umbellus",
"prairie chicken, prairie grouse, prairie fowl",
"peacock",
"quail",
"partridge",
"African grey, African gray, Psittacus erithacus",
"macaw",
"sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
"lorikeet",
"coucal",
"bee eater",
"hornbill",
"hummingbird",
"jacamar",
"toucan",
"drake",
"red-breasted merganser, Mergus serrator",
"goose",
"black swan, Cygnus atratus",
"tusker",
"echidna, spiny anteater, anteater",
"platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus", # noqa: E501
"wallaby, brush kangaroo",
"koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus", # noqa: E501
"wombat",
"jellyfish",
"sea anemone, anemone",
"brain coral",
"flatworm, platyhelminth",
"nematode, nematode worm, roundworm",
"conch",
"snail",
"slug",
"sea slug, nudibranch",
"chiton, coat-of-mail shell, sea cradle, polyplacophore",
"chambered nautilus, pearly nautilus, nautilus",
"Dungeness crab, Cancer magister",
"rock crab, Cancer irroratus",
"fiddler crab",
"king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica", # noqa: E501
"American lobster, Northern lobster, Maine lobster, Homarus americanus", # noqa: E501
"spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish", # noqa: E501
"crayfish, crawfish, crawdad, crawdaddy",
"hermit crab",
"isopod",
"white stork, Ciconia ciconia",
"black stork, Ciconia nigra",
"spoonbill",
"flamingo",
"little blue heron, Egretta caerulea",
"American egret, great white heron, Egretta albus",
"bittern",
"crane",
"limpkin, Aramus pictus",
"European gallinule, Porphyrio porphyrio",
"American coot, marsh hen, mud hen, water hen, Fulica americana",
"bustard",
"ruddy turnstone, Arenaria interpres",
"red-backed sandpiper, dunlin, Erolia alpina",
"redshank, Tringa totanus",
"dowitcher",
"oystercatcher, oyster catcher",
"pelican",
"king penguin, Aptenodytes patagonica",
"albatross, mollymawk",
"grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus", # noqa: E501
"killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
"dugong, Dugong dugon",
"sea lion",
"Chihuahua",
"Japanese spaniel",
"Maltese dog, Maltese terrier, Maltese",
"Pekinese, Pekingese, Peke",
"Shih-Tzu",
"Blenheim spaniel",
"papillon",
"toy terrier",
"Rhodesian ridgeback",
"Afghan hound, Afghan",
"basset, basset hound",
"beagle",
"bloodhound, sleuthhound",
"bluetick",
"black-and-tan coonhound",
"Walker hound, Walker foxhound",
"English foxhound",
"redbone",
"borzoi, Russian wolfhound",
"Irish wolfhound",
"Italian greyhound",
"whippet",
"Ibizan hound, Ibizan Podenco",
"Norwegian elkhound, elkhound",
"otterhound, otter hound",
"Saluki, gazelle hound",
"Scottish deerhound, deerhound",
"Weimaraner",
"Staffordshire bullterrier, Staffordshire bull terrier",
"American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier", # noqa: E501
"Bedlington terrier",
"Border terrier",
"Kerry blue terrier",
"Irish terrier",
"Norfolk terrier",
"Norwich terrier",
"Yorkshire terrier",
"wire-haired fox terrier",
"Lakeland terrier",
"Sealyham terrier, Sealyham",
"Airedale, Airedale terrier",
"cairn, cairn terrier",
"Australian terrier",
"Dandie Dinmont, Dandie Dinmont terrier",
"Boston bull, Boston terrier",
"miniature schnauzer",
"giant schnauzer",
"standard schnauzer",
"Scotch terrier, Scottish terrier, Scottie",
"Tibetan terrier, chrysanthemum dog",
"silky terrier, Sydney silky",
"soft-coated wheaten terrier",
"West Highland white terrier",
"Lhasa, Lhasa apso",
"flat-coated retriever",
"curly-coated retriever",
"golden retriever",
"Labrador retriever",
"Chesapeake Bay retriever",
"German short-haired pointer",
"vizsla, Hungarian pointer",
"English setter",
"Irish setter, red setter",
"Gordon setter",
"Brittany spaniel",
"clumber, clumber spaniel",
"English springer, English springer spaniel",
"Welsh springer spaniel",
"cocker spaniel, English cocker spaniel, cocker",
"Sussex spaniel",
"Irish water spaniel",
"kuvasz",
"schipperke",
"groenendael",
"malinois",
"briard",
"kelpie",
"komondor",
"Old English sheepdog, bobtail",
"Shetland sheepdog, Shetland sheep dog, Shetland",
"collie",
"Border collie",
"Bouvier des Flandres, Bouviers des Flandres",
"Rottweiler",
"German shepherd, German shepherd dog, German police dog, alsatian",
"Doberman, Doberman pinscher",
"miniature pinscher",
"Greater Swiss Mountain dog",
"Bernese mountain dog",
"Appenzeller",
"EntleBucher",
"boxer",
"bull mastiff",
"Tibetan mastiff",
"French bulldog",
"Great Dane",
"Saint Bernard, St Bernard",
"Eskimo dog, husky",
"malamute, malemute, Alaskan malamute",
"Siberian husky",
"dalmatian, coach dog, carriage dog",
"affenpinscher, monkey pinscher, monkey dog",
"basenji",
"pug, pug-dog",
"Leonberg",
"Newfoundland, Newfoundland dog",
"Great Pyrenees",
"Samoyed, Samoyede",
"Pomeranian",
"chow, chow chow",
"keeshond",
"Brabancon griffon",
"Pembroke, Pembroke Welsh corgi",
"Cardigan, Cardigan Welsh corgi",
"toy poodle",
"miniature poodle",
"standard poodle",
"Mexican hairless",
"timber wolf, grey wolf, gray wolf, Canis lupus",
"white wolf, Arctic wolf, Canis lupus tundrarum",
"red wolf, maned wolf, Canis rufus, Canis niger",
"coyote, prairie wolf, brush wolf, Canis latrans",
"dingo, warrigal, warragal, Canis dingo",
"dhole, Cuon alpinus",
"African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
"hyena, hyaena",
"red fox, Vulpes vulpes",
"kit fox, Vulpes macrotis",
"Arctic fox, white fox, Alopex lagopus",
"grey fox, gray fox, Urocyon cinereoargenteus",
"tabby, tabby cat",
"tiger cat",
"Persian cat",
"Siamese cat, Siamese",
"Egyptian cat",
"cougar, puma, catamount, mountain lion, painter, panther, Felis concolor", # noqa: E501
"lynx, catamount",
"leopard, Panthera pardus",
"snow leopard, ounce, Panthera uncia",
"jaguar, panther, Panthera onca, Felis onca",
"lion, king of beasts, Panthera leo",
"tiger, Panthera tigris",
"cheetah, chetah, Acinonyx jubatus",
"brown bear, bruin, Ursus arctos",
"American black bear, black bear, Ursus americanus, Euarctos americanus", # noqa: E501
"ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
"sloth bear, Melursus ursinus, Ursus ursinus",
"mongoose",
"meerkat, mierkat",
"tiger beetle",
"ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
"ground beetle, carabid beetle",
"long-horned beetle, longicorn, longicorn beetle",
"leaf beetle, chrysomelid",
"dung beetle",
"rhinoceros beetle",
"weevil",
"fly",
"bee",
"ant, emmet, pismire",
"grasshopper, hopper",
"cricket",
"walking stick, walkingstick, stick insect",
"cockroach, roach",
"mantis, mantid",
"cicada, cicala",
"leafhopper",
"lacewing, lacewing fly",
"dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk", # noqa: E501
"damselfly",
"admiral",
"ringlet, ringlet butterfly",
"monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
"cabbage butterfly",
"sulphur butterfly, sulfur butterfly",
"lycaenid, lycaenid butterfly",
"starfish, sea star",
"sea urchin",
"sea cucumber, holothurian",
"wood rabbit, cottontail, cottontail rabbit",
"hare",
"Angora, Angora rabbit",
"hamster",
"porcupine, hedgehog",
"fox squirrel, eastern fox squirrel, Sciurus niger",
"marmot",
"beaver",
"guinea pig, Cavia cobaya",
"sorrel",
"zebra",
"hog, pig, grunter, squealer, Sus scrofa",
"wild boar, boar, Sus scrofa",
"warthog",
"hippopotamus, hippo, river horse, Hippopotamus amphibius",
"ox",
"water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
"bison",
"ram, tup",
"bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis", # noqa: E501
"ibex, Capra ibex",
"hartebeest",
"impala, Aepyceros melampus",
"gazelle",
"Arabian camel, dromedary, Camelus dromedarius",
"llama",
"weasel",
"mink",
"polecat, fitch, foulmart, foumart, Mustela putorius",
"black-footed ferret, ferret, Mustela nigripes",
"otter",
"skunk, polecat, wood pussy",
"badger",
"armadillo",
"three-toed sloth, ai, Bradypus tridactylus",
"orangutan, orang, orangutang, Pongo pygmaeus",
"gorilla, Gorilla gorilla",
"chimpanzee, chimp, Pan troglodytes",
"gibbon, Hylobates lar",
"siamang, Hylobates syndactylus, Symphalangus syndactylus",
"guenon, guenon monkey",
"patas, hussar monkey, Erythrocebus patas",
"baboon",
"macaque",
"langur",
"colobus, colobus monkey",
"proboscis monkey, Nasalis larvatus",
"marmoset",
"capuchin, ringtail, Cebus capucinus",
"howler monkey, howler",
"titi, titi monkey",
"spider monkey, Ateles geoffroyi",
"squirrel monkey, Saimiri sciureus",
"Madagascar cat, ring-tailed lemur, Lemur catta",
"indri, indris, Indri indri, Indri brevicaudatus",
"Indian elephant, Elephas maximus",
"African elephant, Loxodonta africana",
"lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
"giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
"barracouta, snoek",
"eel",
"coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch", # noqa: E501
"rock beauty, Holocanthus tricolor",
"anemone fish",
"sturgeon",
"gar, garfish, garpike, billfish, Lepisosteus osseus",
"lionfish",
"puffer, pufferfish, blowfish, globefish",
"abacus",
"abaya",
"academic gown, academic robe, judge's robe",
"accordion, piano accordion, squeeze box",
"acoustic guitar",
"aircraft carrier, carrier, flattop, attack aircraft carrier",
"airliner",
"airship, dirigible",
"altar",
"ambulance",
"amphibian, amphibious vehicle",
"analog clock",
"apiary, bee house",
"apron",
"ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin", # noqa: E501
"assault rifle, assault gun",
"backpack, back pack, knapsack, packsack, rucksack, haversack",
"bakery, bakeshop, bakehouse",
"balance beam, beam",
"balloon",
"ballpoint, ballpoint pen, ballpen, Biro",
"Band Aid",
"banjo",
"bannister, banister, balustrade, balusters, handrail",
"barbell",
"barber chair",
"barbershop",
"barn",
"barometer",
"barrel, cask",
"barrow, garden cart, lawn cart, wheelbarrow",
"baseball",
"basketball",
"bassinet",
"bassoon",
"bathing cap, swimming cap",
"bath towel",
"bathtub, bathing tub, bath, tub",
"beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon", # noqa: E501
"beacon, lighthouse, beacon light, pharos",
"beaker",
"bearskin, busby, shako",
"beer bottle",
"beer glass",
"bell cote, bell cot",
"bib",
"bicycle-built-for-two, tandem bicycle, tandem",
"bikini, two-piece",
"binder, ring-binder",
"binoculars, field glasses, opera glasses",
"birdhouse",
"boathouse",
"bobsled, bobsleigh, bob",
"bolo tie, bolo, bola tie, bola",
"bonnet, poke bonnet",
"bookcase",
"bookshop, bookstore, bookstall",
"bottlecap",
"bow",
"bow tie, bow-tie, bowtie",
"brass, memorial tablet, plaque",
"brassiere, bra, bandeau",
"breakwater, groin, groyne, mole, bulwark, seawall, jetty",
"breastplate, aegis, egis",
"broom",
"bucket, pail",
"buckle",
"bulletproof vest",
"bullet train, bullet",
"butcher shop, meat market",
"cab, hack, taxi, taxicab",
"caldron, cauldron",
"candle, taper, wax light",
"cannon",
"canoe",
"can opener, tin opener",
"cardigan",
"car mirror",
"carousel, carrousel, merry-go-round, roundabout, whirligig",
"carpenter's kit, tool kit",
"carton",
"car wheel",
"cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM", # noqa: E501
"cassette",
"cassette player",
"castle",
"catamaran",
"CD player",
"cello, violoncello",
"cellular telephone, cellular phone, cellphone, cell, mobile phone",
"chain",
"chainlink fence",
"chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour", # noqa: E501
"chain saw, chainsaw",
"chest",
"chiffonier, commode",
"chime, bell, gong",
"china cabinet, china closet",
"Christmas stocking",
"church, church building",
"cinema, movie theater, movie theatre, movie house, picture palace",
"cleaver, meat cleaver, chopper",
"cliff dwelling",
"cloak",
"clog, geta, patten, sabot",
"cocktail shaker",
"coffee mug",
"coffeepot",
"coil, spiral, volute, whorl, helix",
"combination lock",
"computer keyboard, keypad",
"confectionery, confectionary, candy store",
"container ship, containership, container vessel",
"convertible",
"corkscrew, bottle screw",
"cornet, horn, trumpet, trump",
"cowboy boot",
"cowboy hat, ten-gallon hat",
"cradle",
"crane",
"crash helmet",
"crate",
"crib, cot",
"Crock Pot",
"croquet ball",
"crutch",
"cuirass",
"dam, dike, dyke",
"desk",
"desktop computer",
"dial telephone, dial phone",
"diaper, nappy, napkin",
"digital clock",
"digital watch",
"dining table, board",
"dishrag, dishcloth",
"dishwasher, dish washer, dishwashing machine",
"disk brake, disc brake",
"dock, dockage, docking facility",
"dogsled, dog sled, dog sleigh",
"dome",
"doormat, welcome mat",
"drilling platform, offshore rig",
"drum, membranophone, tympan",
"drumstick",
"dumbbell",
"Dutch oven",
"electric fan, blower",
"electric guitar",
"electric locomotive",
"entertainment center",
"envelope",
"espresso maker",
"face powder",
"feather boa, boa",
"file, file cabinet, filing cabinet",
"fireboat",
"fire engine, fire truck",
"fire screen, fireguard",
"flagpole, flagstaff",
"flute, transverse flute",
"folding chair",
"football helmet",
"forklift",
"fountain",
"fountain pen",
"four-poster",
"freight car",
"French horn, horn",
"frying pan, frypan, skillet",
"fur coat",
"garbage truck, dustcart",
"gasmask, respirator, gas helmet",
"gas pump, gasoline pump, petrol pump, island dispenser",
"goblet",
"go-kart",
"golf ball",
"golfcart, golf cart",
"gondola",
"gong, tam-tam",
"gown",
"grand piano, grand",
"greenhouse, nursery, glasshouse",
"grille, radiator grille",
"grocery store, grocery, food market, market",
"guillotine",
"hair slide",
"hair spray",
"half track",
"hammer",
"hamper",
"hand blower, blow dryer, blow drier, hair dryer, hair drier",
"hand-held computer, hand-held microcomputer",
"handkerchief, hankie, hanky, hankey",
"hard disc, hard disk, fixed disk",
"harmonica, mouth organ, harp, mouth harp",
"harp",
"harvester, reaper",
"hatchet",
"holster",
"home theater, home theatre",
"honeycomb",
"hook, claw",
"hoopskirt, crinoline",
"horizontal bar, high bar",
"horse cart, horse-cart",
"hourglass",
"iPod",
"iron, smoothing iron",
"jack-o'-lantern",
"jean, blue jean, denim",
"jeep, landrover",
"jersey, T-shirt, tee shirt",
"jigsaw puzzle",
"jinrikisha, ricksha, rickshaw",
"joystick",
"kimono",
"knee pad",
"knot",
"lab coat, laboratory coat",
"ladle",
"lampshade, lamp shade",
"laptop, laptop computer",
"lawn mower, mower",
"lens cap, lens cover",
"letter opener, paper knife, paperknife",
"library",
"lifeboat",
"lighter, light, igniter, ignitor",
"limousine, limo",
"liner, ocean liner",
"lipstick, lip rouge",
"Loafer",
"lotion",
"loudspeaker, speaker, speaker unit, loudspeaker system, speaker system", # noqa: E501
"loupe, jeweler's loupe",
"lumbermill, sawmill",
"magnetic compass",
"mailbag, postbag",
"mailbox, letter box",
"maillot",
"maillot, tank suit",
"manhole cover",
"maraca",
"marimba, xylophone",
"mask",
"matchstick",
"maypole",
"maze, labyrinth",
"measuring cup",
"medicine chest, medicine cabinet",
"megalith, megalithic structure",
"microphone, mike",
"microwave, microwave oven",
"military uniform",
"milk can",
"minibus",
"miniskirt, mini",
"minivan",
"missile",
"mitten",
"mixing bowl",
"mobile home, manufactured home",
"Model T",
"modem",
"monastery",
"monitor",
"moped",
"mortar",
"mortarboard",
"mosque",
"mosquito net",
"motor scooter, scooter",
"mountain bike, all-terrain bike, off-roader",
"mountain tent",
"mouse, computer mouse",
"mousetrap",
"moving van",
"muzzle",
"nail",
"neck brace",
"necklace",
"nipple",
"notebook, notebook computer",
"obelisk",
"oboe, hautboy, hautbois",
"ocarina, sweet potato",
"odometer, hodometer, mileometer, milometer",
"oil filter",
"organ, pipe organ",
"oscilloscope, scope, cathode-ray oscilloscope, CRO",
"overskirt",
"oxcart",
"oxygen mask",
"packet",
"paddle, boat paddle",
"paddlewheel, paddle wheel",
"padlock",
"paintbrush",
"pajama, pyjama, pj's, jammies",
"palace",
"panpipe, pandean pipe, syrinx",
"paper towel",
"parachute, chute",
"parallel bars, bars",
"park bench",
"parking meter",
"passenger car, coach, carriage",
"patio, terrace",
"pay-phone, pay-station",
"pedestal, plinth, footstall",
"pencil box, pencil case",
"pencil sharpener",
"perfume, essence",
"Petri dish",
"photocopier",
"pick, plectrum, plectron",
"pickelhaube",
"picket fence, paling",
"pickup, pickup truck",
"pier",
"piggy bank, penny bank",
"pill bottle",
"pillow",
"ping-pong ball",
"pinwheel",
"pirate, pirate ship",
"pitcher, ewer",
"plane, carpenter's plane, woodworking plane",
"planetarium",
"plastic bag",
"plate rack",
"plow, plough",
"plunger, plumber's helper",
"Polaroid camera, Polaroid Land camera",
"pole",
"police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria", # noqa: E501
"poncho",
"pool table, billiard table, snooker table",
"pop bottle, soda bottle",
"pot, flowerpot",
"potter's wheel",
"power drill",
"prayer rug, prayer mat",
"printer",
"prison, prison house",
"projectile, missile",
"projector",
"puck, hockey puck",
"punching bag, punch bag, punching ball, punchball",
"purse",
"quill, quill pen",
"quilt, comforter, comfort, puff",
"racer, race car, racing car",
"racket, racquet",
"radiator",
"radio, wireless",
"radio telescope, radio reflector",
"rain barrel",
"recreational vehicle, RV, R.V.",
"reel",
"reflex camera",
"refrigerator, icebox",
"remote control, remote",
"restaurant, eating house, eating place, eatery",
"revolver, six-gun, six-shooter",
"rifle",
"rocking chair, rocker",
"rotisserie",
"rubber eraser, rubber, pencil eraser",
"rugby ball",
"rule, ruler",
"running shoe",
"safe",
"safety pin",
"saltshaker, salt shaker",
"sandal",
"sarong",
"sax, saxophone",
"scabbard",
"scale, weighing machine",
"school bus",
"schooner",
"scoreboard",
"screen, CRT screen",
"screw",
"screwdriver",
"seat belt, seatbelt",
"sewing machine",
"shield, buckler",
"shoe shop, shoe-shop, shoe store",
"shoji",
"shopping basket",
"shopping cart",
"shovel",
"shower cap",
"shower curtain",
"ski",
"ski mask",
"sleeping bag",
"slide rule, slipstick",
"sliding door",
"slot, one-armed bandit",
"snorkel",
"snowmobile",
"snowplow, snowplough",
"soap dispenser",
"soccer ball",
"sock",
"solar dish, solar collector, solar furnace",
"sombrero",
"soup bowl",
"space bar",
"space heater",
"space shuttle",
"spatula",
"speedboat",
"spider web, spider's web",
"spindle",
"sports car, sport car",
"spotlight, spot",
"stage",
"steam locomotive",
"steel arch bridge",
"steel drum",
"stethoscope",
"stole",
"stone wall",
"stopwatch, stop watch",
"stove",
"strainer",
"streetcar, tram, tramcar, trolley, trolley car",
"stretcher",
"studio couch, day bed",
"stupa, tope",
"submarine, pigboat, sub, U-boat",
"suit, suit of clothes",
"sundial",
"sunglass",
"sunglasses, dark glasses, shades",
"sunscreen, sunblock, sun blocker",
"suspension bridge",
"swab, swob, mop",
"sweatshirt",
"swimming trunks, bathing trunks",
"swing",
"switch, electric switch, electrical switch",
"syringe",
"table lamp",
"tank, army tank, armored combat vehicle, armoured combat vehicle",
"tape player",
"teapot",
"teddy, teddy bear",
"television, television system",
"tennis ball",
"thatch, thatched roof",
"theater curtain, theatre curtain",
"thimble",
"thresher, thrasher, threshing machine",
"throne",
"tile roof",
"toaster",
"tobacco shop, tobacconist shop, tobacconist",
"toilet seat",
"torch",
"totem pole",
"tow truck, tow car, wrecker",
"toyshop",
"tractor",
"trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi", # noqa: E501
"tray",
"trench coat",
"tricycle, trike, velocipede",
"trimaran",
"tripod",
"triumphal arch",
"trolleybus, trolley coach, trackless trolley",
"trombone",
"tub, vat",
"turnstile",
"typewriter keyboard",
"umbrella",
"unicycle, monocycle",
"upright, upright piano",
"vacuum, vacuum cleaner",
"vase",
"vault",
"velvet",
"vending machine",
"vestment",
"viaduct",
"violin, fiddle",
"volleyball",
"waffle iron",
"wall clock",
"wallet, billfold, notecase, pocketbook",
"wardrobe, closet, press",
"warplane, military plane",
"washbasin, handbasin, washbowl, lavabo, wash-hand basin",
"washer, automatic washer, washing machine",
"water bottle",
"water jug",
"water tower",
"whiskey jug",
"whistle",
"wig",
"window screen",
"window shade",
"Windsor tie",
"wine bottle",
"wing",
"wok",
"wooden spoon",
"wool, woolen, woollen",
"worm fence, snake fence, snake-rail fence, Virginia fence",
"wreck",
"yawl",
"yurt",
"web site, website, internet site, site",
"comic book",
"crossword puzzle, crossword",
"street sign",
"traffic light, traffic signal, stoplight",
"book jacket, dust cover, dust jacket, dust wrapper",
"menu",
"plate",
"guacamole",
"consomme",
"hot pot, hotpot",
"trifle",
"ice cream, icecream",
"ice lolly, lolly, lollipop, popsicle",
"French loaf",
"bagel, beigel",
"pretzel",
"cheeseburger",
"hotdog, hot dog, red hot",
"mashed potato",
"head cabbage",
"broccoli",
"cauliflower",
"zucchini, courgette",
"spaghetti squash",
"acorn squash",
"butternut squash",
"cucumber, cuke",
"artichoke, globe artichoke",
"bell pepper",
"cardoon",
"mushroom",
"Granny Smith",
"strawberry",
"orange",
"lemon",
"fig",
"pineapple, ananas",
"banana",
"jackfruit, jak, jack",
"custard apple",
"pomegranate",
"hay",
"carbonara",
"chocolate sauce, chocolate syrup",
"dough",
"meat loaf, meatloaf",
"pizza, pizza pie",
"potpie",
"burrito",
"red wine",
"espresso",
"cup",
"eggnog",
"alp",
"bubble",
"cliff, drop, drop-off",
"coral reef",
"geyser",
"lakeside, lakeshore",
"promontory, headland, head, foreland",
"sandbar, sand bar",
"seashore, coast, seacoast, sea-coast",
"valley, vale",
"volcano",
"ballplayer, baseball player",
"groom, bridegroom",
"scuba diver",
"rapeseed",
"daisy",
"yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum", # noqa: E501
"corn",
"acorn",
"hip, rose hip, rosehip",
"buckeye, horse chestnut, conker",
"coral fungus",
"agaric",
"gyromitra",
"stinkhorn, carrion fungus",
"earthstar",
"hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa", # noqa: E501
"bolete",
"ear, spike, capitulum",
"toilet tissue, toilet paper, bathroom tissue",
]
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .activation import build_activation
from .cross_entropy import ParallelCrossEntropyLoss
from .embedding import Embedding, SinePositionalEmbedding, VocabEmbedding, PatchEmbedding
from .layer_norm import LayerNorm, RMSLayerNorm
from .linear import Linear, Linear1D
from .lm_logits import LMLogits
from .mlp import MLP
from .transformer_layer import TransformerLayer
from .attention import MultiheadAttention
from .droppath import DropPath, drop_path
__all__ = [
"Embedding",
"VocabEmbedding",
"SinePositionalEmbedding",
"PatchEmbedding",
"build_activation",
"Linear",
"Linear1D",
"MLP",
"LayerNorm",
"RMSLayerNorm",
"TransformerLayer",
"MultiheadAttention",
"ParallelCrossEntropyLoss",
"LMLogits",
"drop_path",
"DropPath",
]
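# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): the names re-exported above can
# be imported directly from the package instead of from the individual modules.
# The package path `libai.layers` and the constructor arguments shown below are
# illustrative assumptions only; check each module for the real signatures.
#
#     from libai.layers import Linear, LayerNorm, build_activation
#
#     proj = Linear(1024, 4096)        # assumed: (in_features, out_features)
#     norm = LayerNorm(1024)           # assumed: normalized size as first arg
#     act = build_activation("gelu")   # assumed: activation selected by name
# ---------------------------------------------------------------------------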