Commit 9fdb7dab authored by yuguo960516

bloom

# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import atexit
import functools
import logging
import os
import sys
import time
from collections import Counter
from termcolor import colored
from libai.utils.file_io import PathManager
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/utils/logger.py
# --------------------------------------------------------
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(output=None, distributed_rank=0, *, color=True, name="libai", abbrev_name=None):
"""
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "detectron2" to "d2" and leave other
modules unchanged.
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = "lb" if name == "libai" else name
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
# use 1K buffer if writing to cloud storage
io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
atexit.register(io.close)
return io
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""
def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "libai"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
        If called with `n=1, key=("caller", "message")`, this function
        will only skip logging if the same caller has already logged the same message before.
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
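# ---------------------------------------------------------------------------
# Hedged usage sketch (editor's illustration, only using the helpers defined
# above): sets up the rank-0 logger and exercises the rate-limited logging
# helpers. The output directory "output/demo" is an arbitrary placeholder.
if __name__ == "__main__":
    demo_logger = setup_logger(output="output/demo", distributed_rank=0, name="libai")
    demo_logger.info("logger initialized")
    for step in range(5):
        # Logged at most twice, keyed on this call site.
        log_first_n(logging.INFO, "first-n message", n=2)
        # Logged on calls 1 and 4 (i.e. every 3rd call from this site).
        log_every_n(logging.INFO, "every-n message", n=3)
        # Logged at most once per 10 seconds from this call site.
        log_every_n_seconds(logging.INFO, "rate-limited message", n=10)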
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import io
import logging
from dataclasses import dataclass
from queue import Queue
from threading import Thread
from typing import IO, Callable, Optional, Union
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/iopath/blob/main/iopath/common/non_blocking_io.py
# --------------------------------------------------------
"""
This file is used for asynchronous file operations.
When `opena` is called for the first time for a specific
`PathHandler`, a `NonBlockingIOManager` is instantiated. The
manager returns a `NonBlockingIO` (or `NonBlockingBufferedIO`)
instance to the caller, and the manager maintains all of the
thread management and data management.
"""
@dataclass
class PathData:
"""
Manage the IO job queue and polling thread for a single
path. This is done to ensure that write calls to the same
path are serialized so they are written in the same order
as they were called.
On each `f.write` call where `f` is of type `NonBlockingIO`,
we send the job to the manager where it is enqueued to the
Queue. The polling Thread picks up on the job, executes it,
waits for it to finish, and then continues to poll.
"""
queue: Queue
thread: Thread
class NonBlockingIOManager:
"""
All `opena` calls pass through this class so that it can
keep track of the threads for proper cleanup at the end
of the script. Each path that is opened with `opena` is
assigned a single queue and polling thread that is kept
open until it is cleaned up by `PathManager.async_join()`.
"""
def __init__(
self,
buffered: Optional[bool] = False,
executor: Optional[concurrent.futures.Executor] = None,
) -> None:
"""
Args:
buffered (bool): IO instances will be `NonBlockingBufferedIO`
or `NonBlockingIO` based on this value. This bool is set
manually for each `PathHandler` in `_opena`.
executor: User can optionally attach a custom executor to
perform async operations through `PathHandler.__init__`.
"""
self._path_to_data = {} # Map from path to `PathData` object
self._buffered = buffered
self._IO = NonBlockingBufferedIO if self._buffered else NonBlockingIO
self._pool = executor or concurrent.futures.ThreadPoolExecutor()
def get_non_blocking_io(
self,
path: str,
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: Optional[int] = -1,
) -> Union[IO[str], IO[bytes]]:
"""
Called by `PathHandler._opena` with the path and returns a
`NonBlockingIO` instance.
Args:
path (str): A path str to operate on. This path should be
simplified to ensure that each absolute path has only a single
path str that maps onto it. For example, in `NativePathHandler`,
we can use `os.path.normpath`.
io_obj (IO): a reference to the IO object returned by the
`PathHandler._open` function.
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
"""
if not self._buffered and buffering != -1:
raise ValueError(
"NonBlockingIO is not using a buffered writer but `buffering` "
f"arg is set to non-default value of {buffering} != -1."
)
if path not in self._path_to_data:
# Initialize job queue and a polling thread
queue = Queue()
t = Thread(target=self._poll_jobs, args=(queue,))
t.start()
# Store the `PathData`
self._path_to_data[path] = PathData(queue, t)
kwargs = {} if not self._buffered else {"buffering": buffering}
return self._IO(
notify_manager=lambda io_callable: ( # Pass async jobs to manager
self._path_to_data[path].queue.put(io_callable)
),
io_obj=io_obj,
callback_after_file_close=callback_after_file_close,
**kwargs,
)
    def _poll_jobs(self, queue: Queue) -> None:
"""
A single thread runs this loop. It waits for an IO callable to be
placed in a specific path's `Queue` where the queue contains
callable functions. It then waits for the IO job to be completed
before looping to ensure write order.
"""
while True:
# `func` is a callable function (specifically a lambda function)
# and can be any of:
# - func = file.write(b)
# - func = file.close()
# - func = None
func = queue.get() # Blocks until item read.
if func is None: # Thread join signal.
break
self._pool.submit(func).result() # Wait for job to finish.
def _join(self, path: Optional[str] = None) -> bool:
"""
Waits for write jobs for a specific path or waits for all
write jobs for the path handler if no path is provided.
Args:
path (str): Pass in a file path and will wait for the
asynchronous jobs to be completed for that file path.
If no path is passed in, then all threads operating
on all file paths will be joined.
"""
if path and path not in self._path_to_data:
raise ValueError(
f"{path} has no async IO associated with it. "
f"Make sure `opena({path})` is called first."
)
# If a `_close` call fails, we print the error and continue
# closing the rest of the IO objects.
paths_to_close = [path] if path else list(self._path_to_data.keys())
success = True
for _path in paths_to_close:
try:
path_data = self._path_to_data.pop(_path)
path_data.queue.put(None)
path_data.thread.join()
except Exception:
logger = logging.getLogger(__name__)
logger.exception(f"`NonBlockingIO` thread for {_path} failed to join.")
success = False
return success
def _close_thread_pool(self) -> bool:
"""
Closes the ThreadPool.
"""
try:
self._pool.shutdown()
except Exception:
logger = logging.getLogger(__name__)
logger.exception("`NonBlockingIO` thread pool failed to close.")
return False
return True
# NOTE: We currently only support asynchronous writes (not reads).
class NonBlockingIO(io.IOBase):
def __init__(
self,
notify_manager: Callable[[Callable[[], None]], None],
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
) -> None:
"""
Returned to the user on an `opena` call. Uses a Queue to manage the
IO jobs that need to be run to ensure order preservation and a
polling Thread that checks the Queue. Implementation for these are
lifted to `NonBlockingIOManager` since `NonBlockingIO` closes upon
leaving the context block.
NOTE: Writes to the same path are serialized so they are written in
the same order as they were called but writes to distinct paths can
happen concurrently.
Args:
notify_manager (Callable): a callback function passed in from the
`NonBlockingIOManager` so that all IO jobs can be stored in
the manager. It takes in a single argument, namely another
callable function.
Example usage:
```
notify_manager(lambda: file.write(data))
notify_manager(lambda: file.close())
```
Here, we tell `NonBlockingIOManager` to add a write callable
to the path's Queue, and then to add a close callable to the
path's Queue. The path's polling Thread then executes the write
callable, waits for it to finish, and then executes the close
callable. Using `lambda` allows us to pass callables to the
manager.
io_obj (IO): a reference to the IO object returned by the
`PathHandler._open` function.
callback_after_file_close (Callable): An optional argument that can
be passed to perform operations that depend on the asynchronous
writes being completed. The file is first written to the local
disk and then the callback is executed.
"""
super().__init__()
self._notify_manager = notify_manager
self._io = io_obj
self._callback_after_file_close = callback_after_file_close
self._close_called = False
def readable(self) -> bool:
return False
def writable(self) -> bool:
return True
def seekable(self) -> bool:
return True
def write(self, b: Union[bytes, bytearray]) -> None:
"""
Called on `f.write()`. Gives the manager the write job to call.
"""
self._notify_manager(lambda: self._io.write(b))
def seek(self, offset: int, whence: int = 0) -> int:
"""
Called on `f.seek()`.
"""
self._notify_manager(lambda: self._io.seek(offset, whence))
def tell(self) -> int:
"""
Called on `f.tell()`.
"""
raise ValueError("ioPath async writes does not support `tell` calls.")
def truncate(self, size: int = None) -> int:
"""
Called on `f.truncate()`.
"""
self._notify_manager(lambda: self._io.truncate(size))
def close(self) -> None:
"""
Called on `f.close()` or automatically by the context manager.
We add the `close` call to the file's queue to make sure that
the file is not closed before all of the write jobs are complete.
"""
# `ThreadPool` first closes the file and then executes the callback.
# We only execute the callback once even if there are multiple
# `f.close` calls.
self._notify_manager(lambda: self._io.close())
if not self._close_called and self._callback_after_file_close:
self._notify_manager(self._callback_after_file_close)
self._close_called = True
# NOTE: To use this class, use `buffered=True` in `NonBlockingIOManager`.
# NOTE: This class expects the IO mode to be buffered.
class NonBlockingBufferedIO(io.IOBase):
MAX_BUFFER_BYTES = 10 * 1024 * 1024 # 10 MiB
def __init__(
self,
notify_manager: Callable[[Callable[[], None]], None],
io_obj: Union[IO[str], IO[bytes]],
callback_after_file_close: Optional[Callable[[None], None]] = None,
buffering: int = -1,
) -> None:
"""
Buffered version of `NonBlockingIO`. All write data is stored in an
IO buffer until the buffer is full, or `flush` or `close` is called.
Args:
Same as `NonBlockingIO` args.
buffering (int): An optional argument to set the buffer size for
buffered asynchronous writing.
"""
super().__init__()
self._notify_manager = notify_manager
self._io = io_obj
self._callback_after_file_close = callback_after_file_close
self._buffers = [io.BytesIO()]
self._buffer_size = buffering if buffering > 0 else self.MAX_BUFFER_BYTES
self._close_called = False
def readable(self) -> bool:
return False
def writable(self) -> bool:
return True
def seekable(self) -> bool:
return False
def write(self, b: Union[bytes, bytearray]) -> None:
"""
Called on `f.write()`. Gives the manager the write job to call.
"""
buffer = self._buffers[-1]
with memoryview(b) as view:
buffer.write(view)
if buffer.tell() < self._buffer_size:
return
self.flush()
def close(self) -> None:
"""
Called on `f.close()` or automatically by the context manager.
We add the `close` call to the file's queue to make sure that
the file is not closed before all of the write jobs are complete.
"""
self.flush()
# Close the last buffer created by `flush`.
self._notify_manager(lambda: self._buffers[-1].close())
# `ThreadPool` first closes the file and then executes the callback.
self._notify_manager(lambda: self._io.close())
if not self._close_called and self._callback_after_file_close:
self._notify_manager(self._callback_after_file_close)
self._close_called = True
def flush(self) -> None:
"""
Called on `f.write()` if the buffer is filled (or overfilled). Can
also be explicitly called by user.
NOTE: Buffering is used in a strict manner. Any buffer that exceeds
`self._buffer_size` will be broken into multiple write jobs where
each has a write call with `self._buffer_size` size.
"""
buffer = self._buffers[-1]
if buffer.tell() == 0:
return
pos = 0
total_size = buffer.seek(0, io.SEEK_END)
view = buffer.getbuffer()
# Chunk the buffer in case it is larger than the buffer size.
while pos < total_size:
item = view[pos : pos + self._buffer_size]
# `item=item` is needed due to Python's late binding closures.
self._notify_manager(lambda item=item: self._io.write(item))
pos += self._buffer_size
# Close buffer immediately after being written to file and create
# a new buffer.
self._notify_manager(lambda: buffer.close())
self._buffers.append(io.BytesIO())
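# ---------------------------------------------------------------------------
# Hedged usage sketch (editor's illustration): drives `NonBlockingIOManager`
# directly with a plain local file instead of going through a `PathHandler`,
# which is how it is normally reached via `opena`/`async_join`. The file name
# "async_demo.bin" is an arbitrary placeholder.
if __name__ == "__main__":
    manager = NonBlockingIOManager(buffered=False)
    f = manager.get_non_blocking_io(
        path="async_demo.bin",
        io_obj=open("async_demo.bin", "wb"),
    )
    f.write(b"hello ")   # enqueued and executed by the path's polling thread
    f.write(b"world\n")  # writes to the same path keep their order
    f.close()            # close is queued behind the pending writes
    assert manager._join("async_demo.bin")  # wait for the queue to drain
    assert manager._close_thread_pool()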
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import perf_counter
from typing import Optional
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/common/timer.py
# --------------------------------------------------------
class Timer:
"""
A timer which computes the time elapsed since the start/reset of the timer.
"""
def __init__(self):
self.reset()
def reset(self):
"""
Reset the timer.
"""
self._start = perf_counter()
self._paused: Optional[float] = None
self._total_paused = 0
self._count_start = 1
def pause(self):
"""
Pause the timer.
"""
if self._paused is not None:
raise ValueError("Trying to pause a Timer that is already paused!")
self._paused = perf_counter()
def is_paused(self) -> bool:
"""
Returns:
bool: whether the timer is currently paused
"""
return self._paused is not None
def resume(self):
"""
Resume the timer.
"""
if self._paused is None:
raise ValueError("Trying to resume a Timer that is not paused!")
self._total_paused += perf_counter() - self._paused
self._paused = None
self._count_start += 1
def seconds(self) -> float:
"""
Returns:
(float): the total number of seconds since the start/reset of the
timer, excluding the time when the timer is paused.
"""
if self._paused is not None:
end_time: float = self._paused # type: ignore
else:
end_time = perf_counter()
return end_time - self._start - self._total_paused
def avg_seconds(self) -> float:
"""
Returns:
(float): the average number of seconds between every start/reset and
pause.
"""
return self.seconds() / self._count_start
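# ---------------------------------------------------------------------------
# Hedged usage sketch (editor's illustration): times a block of work while
# excluding a paused interval; `time.sleep` stands in for real work.
if __name__ == "__main__":
    import time

    timer = Timer()
    time.sleep(0.2)  # counted
    timer.pause()
    time.sleep(0.2)  # excluded while paused
    timer.resume()
    time.sleep(0.2)  # counted
    print(f"elapsed: {timer.seconds():.2f}s")              # roughly 0.4
    print(f"avg per segment: {timer.avg_seconds():.2f}s")  # seconds() / 2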
# Model name
modelName=BLOOM
# Model description
modelDescription=bloom-7b1
# Application scenarios (multiple tags separated by ASCII commas)
appScenario=文学创作,文本生成
# Framework type (multiple tags separated by ASCII commas)
frameType=OneFlow,Libai
from omegaconf import DictConfig
from libai.config import LazyCall
from projects.BLOOM.modeling.bloom_model import BloomModel
cfg = dict(
# model
vocab_size=250880,
hidden_size=64,
hidden_layers=2,
n_head=8,
padding_idx=3,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
pretraining_tp=1,
slow_but_exact=False,
amp_enabled=False,
# Inference
is_encoder_decoder=False,
max_length=512,
min_length=0,
do_sample=False,
early_stopping=False,
num_beams=1,
num_beam_groups=1,
diversity_penalty=0.0,
temperature=1.0,
top_k=50,
top_p=1.0,
typical_p=1.0,
repetition_penalty=1.0,
length_penalty=1.0,
no_repeat_ngram_size=0,
encoder_no_repeat_ngram_size=0,
num_return_sequences=1,
chunk_size_feed_forward=0,
output_scores=False,
forced_bos_token_id=None,
forced_eos_token_id=None,
remove_invalid_values=False,
exponential_decay_length_penalty=None,
use_cache=True,
# Tokenizer
pad_token_id=3,
eos_token_id=2,
bos_token_id=1,
sep_token_id=None,
decoder_start_token_id=None,
)
cfg = DictConfig(cfg)
glm_model = LazyCall(BloomModel)(cfg=cfg)
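# ---------------------------------------------------------------------------
# Hedged sketch (editor's illustration): the sizes above are toy values for
# quick tests. Since `cfg` is an omegaconf `DictConfig`, a full-size run would
# override the relevant fields before building the model; the numbers below
# are the commonly reported bloom-7b1 sizes and should be cross-checked
# against the checkpoint's config.json (see `_load_config_from_json` in the
# BLOOM loader).
if __name__ == "__main__":
    cfg.hidden_size = 4096
    cfg.hidden_layers = 30
    cfg.n_head = 32
    print(cfg.hidden_size, cfg.hidden_layers, cfg.n_head)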
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
def bloom_gelu_forward(x):
"""
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple
implementation (inference) to make the model jitable.
Args:
x (`torch.tensor`, *required*):
input hidden states
"""
return x * 0.5 * (1.0 + flow.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
def bloom_gelu_back(g, x):
"""
    Gradient of the tanh approximation of GELU. For reference, the gradient of the exact
    GELU is: 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
Args:
g (`torch.tensor`, *required*):
gradient output tensor
x (`torch.tensor`, *required*):
input tensor
"""
x = x[0]
tanh_out = flow.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
1 + tanh_out
)
return ff * g
class GeLUFunction(flow.autograd.Function):
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
return bloom_gelu_forward(input)
@staticmethod
def backward(ctx, grad_output):
input = ctx.saved_tensors
tmp = bloom_gelu_back(grad_output, input)
return tmp
class BloomGelu(nn.Module):
"""
    BloomGelu wrapper module that uses the simple forward-only implementation in inference
    mode (to keep the model scriptable) and the custom autograd function in training mode to
    get accurate gradients. Partly copied from Megatron-DeepSpeed code and adapted for our needs.
See here why autograd functions are not torchscriptable:
https://github.com/pytorch/pytorch/issues/22329
"""
def __init__(self):
super().__init__()
def forward(self, x):
if self.training:
return GeLUFunction.apply(x)
else:
return bloom_gelu_forward(x)
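# ---------------------------------------------------------------------------
# Hedged sanity-check sketch (editor's illustration): compares the tanh
# approximation above against OneFlow's built-in GELU and exercises the
# custom backward through `BloomGelu` in training mode. Shapes and tolerances
# are arbitrary.
if __name__ == "__main__":
    x = flow.randn(4, 8, requires_grad=True)
    approx = bloom_gelu_forward(x)
    exact = flow.nn.functional.gelu(x)
    print("max forward diff:", (approx - exact).abs().max().item())

    gelu = BloomGelu()
    gelu.train()  # training mode -> GeLUFunction.apply with the custom backward
    gelu(x).sum().backward()
    print("grad shape:", x.grad.shape)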
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
from oneflow import nn
from oneflow.nn import functional as F
from libai.layers import Linear
def dropout_add(x, residual, prob, training):
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *required*):
            residual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
class BloomAttention(nn.Module):
def __init__(
self,
hidden_size,
n_head,
hidden_dropout,
attention_dropout,
pretraining_tp,
slow_but_exact,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
layer_idx=0,
):
super().__init__()
self.pretraining_tp = pretraining_tp
self.slow_but_exact = slow_but_exact
self.hidden_size = hidden_size
self.num_heads = n_head
self.head_dim = self.hidden_size // self.num_heads
self.split_size = self.hidden_size
self.hidden_dropout = hidden_dropout
if output_layer_init_method is None:
output_layer_init_method = init_method
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"`hidden_size` must be divisible by num_heads "
f"(got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
# Layer-wise attention scaling
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
self.query_key_value = Linear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
parallel="col",
init_method=init_method,
layer_idx=layer_idx,
)
self.dense = Linear(
self.hidden_size,
self.hidden_size,
parallel="row",
init_method=output_layer_init_method,
layer_idx=layer_idx,
)
self.attention_dropout = nn.Dropout(attention_dropout)
def _split_heads(self, fused_qkv):
"""
        Split the last dimension into (num_heads, head_dim) without making any copies; the
        results share the same memory storage as `fused_qkv`.
Args:
fused_qkv (`torch.tensor`, *required*):
[batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim]
key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
def _merge_heads(self, x):
"""
        Merge heads together over the last dimension
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // self.num_heads
# First view to decompose the batch size
# batch_size * num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
def forward(
self,
hidden_states,
residual,
alibi,
attention_mask,
layer_past=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, q_length, self.head_dim
)
key_layer = key_layer.permute(0, 2, 3, 1).reshape(
batch_size * self.num_heads, self.head_dim, q_length
)
value_layer = value_layer.transpose(1, 2).reshape(
batch_size * self.num_heads, q_length, self.head_dim
)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = flow.cat((past_key, key_layer), dim=2)
value_layer = flow.cat((past_value, value_layer), dim=1)
_, _, kv_length = key_layer.shape
if use_cache is True:
present = (key_layer, value_layer)
else:
present = None
matmul_result = flow.baddbmm(
alibi,
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
input_dtype = attention_scores.dtype
attn_weights = flow.masked_fill(
attention_scores, attention_mask, flow.finfo(attention_scores.dtype).min
)
attention_probs = F.softmax(attn_weights, dim=-1).to(input_dtype)
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
attention_probs_reshaped = attention_probs.view(
batch_size * self.num_heads, q_length, kv_length
)
context_layer = flow.bmm(attention_probs_reshaped, value_layer)
context_layer = self._merge_heads(context_layer)
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = self.hidden_size / self.pretraining_tp
output_tensor = flow.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
output_tensor = self.dense(context_layer)
output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
outputs = (output_tensor, present)
if output_attentions:
outputs += (attention_probs,)
return outputs
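# ---------------------------------------------------------------------------
# Hedged shape sketch (editor's illustration): reproduces the fused
# alibi + scaled QK^T step from `BloomAttention.forward` on plain local
# tensors, outside the parallel `Linear` layers, to make the layouts explicit.
# All sizes are arbitrary.
if __name__ == "__main__":
    batch_size, num_heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8
    query = flow.randn(batch_size * num_heads, q_len, head_dim)
    key = flow.randn(batch_size * num_heads, head_dim, kv_len)
    value = flow.randn(batch_size * num_heads, kv_len, head_dim)
    alibi = flow.randn(batch_size * num_heads, 1, kv_len)  # broadcast over q_len

    scores = flow.baddbmm(
        alibi, batch1=query, batch2=key, beta=1.0, alpha=1.0 / math.sqrt(head_dim)
    )  # [batch_size * num_heads, q_len, kv_len]
    probs = F.softmax(scores, dim=-1)
    context = flow.bmm(probs, value)  # [batch_size * num_heads, q_len, head_dim]
    print(scores.shape, context.shape)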
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
from oneflow import nn
from libai.config import configurable
from libai.inference.generator.generation_utils import Generator
from libai.layers import Embedding, LayerNorm, LMLogits
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
from projects.BLOOM.modeling.mask import _expand_mask, _make_causal_mask, build_alibi_tensor
from projects.BLOOM.modeling.transformers import BloomBlock
class BloomModel(nn.Module):
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
n_head,
padding_idx,
pretraining_tp=1,
slow_but_exact=False,
initializer_range=0.02,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0,
attention_dropout=0,
amp_enabled=False,
layer_norm_epsilon=1e-12,
cfg=None,
):
super().__init__()
self.cfg = cfg
self.embed_dim = hidden_size
self.num_heads = n_head
self.hidden_layers = hidden_layers
init_method = init_method_normal(initializer_range)
scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)
self.word_embeddings = Embedding(
vocab_size,
self.embed_dim,
padding_idx=padding_idx,
init_method=init_method,
amp_enabled=amp_enabled,
layer_idx=0,
)
self.word_embeddings_layernorm = LayerNorm(
self.embed_dim, eps=layer_norm_epsilon, layer_idx=0
)
self.h = flow.nn.ModuleList(
[
BloomBlock(
hidden_size=hidden_size,
n_head=n_head,
layer_norm_epsilon=layer_norm_epsilon,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
pretraining_tp=pretraining_tp,
slow_but_exact=slow_but_exact,
init_method=init_method,
output_layer_init_method=scaled_init_method,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm, # noqa
layer_idx=i,
)
for i in range(hidden_layers)
]
)
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=layer_norm_epsilon, layer_idx=hidden_layers - 1)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"n_head": cfg.n_head,
"padding_idx": cfg.padding_idx,
"pretraining_tp": cfg.pretraining_tp,
"slow_but_exact": cfg.slow_but_exact,
"apply_residual_connection_post_layernorm": cfg.apply_residual_connection_post_layernorm, # noqa
"hidden_dropout": cfg.hidden_dropout,
"attention_dropout": cfg.attention_dropout,
"amp_enabled": cfg.amp_enabled,
"layer_norm_epsilon": cfg.layer_norm_epsilon,
"cfg": cfg,
}
def _prepare_attn_mask(
self,
attention_mask,
input_shape,
past_key_values_length,
):
combined_attention_mask = None
_, src_length = input_shape
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
"""-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = (
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask
def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
"""
Prepare the head mask if needed.
Args:
head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`,
*optional*):
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for
discard).
num_hidden_layers (`int`):
The number of hidden layers in the model.
is_attention_chunked: (`bool`, *optional*, defaults to `False`):
                Whether the attention scores are computed chunk by chunk.
Returns:
`torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x
seq_length]` or list with `[None]` for each layer.
"""
if head_mask is not None:
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
if is_attention_chunked is True:
head_mask = head_mask.unsqueeze(-1)
else:
head_mask = [None] * num_hidden_layers
return head_mask
def forward(
self,
input_ids=None,
attention_mask=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
):
input_ids = (
input_ids.to_global(placement=dist.get_layer_placement(0))
if input_ids is not None
else input_ids
)
attention_mask = (
attention_mask.to_global(placement=dist.get_layer_placement(0))
if attention_mask is not None
else attention_mask
)
head_mask = (
head_mask.to_global(placement=dist.get_layer_placement(0))
if head_mask is not None
else head_mask
)
inputs_embeds = (
inputs_embeds.to_global(placement=dist.get_layer_placement(0))
if inputs_embeds is not None
else inputs_embeds
)
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
head_mask = self.get_head_mask(head_mask, self.hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = flow.ones(
(batch_size, seq_length_with_past),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
alibi = build_alibi_tensor(attention_mask, self.num_heads, hidden_states.dtype)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
hidden_states = self.ln_f(hidden_states)
return {"last_hidden_state": hidden_states, "past_key_values": presents}
class BloomForCausalLM(nn.Module, Generator):
@configurable
def __init__(
self,
vocab_size,
hidden_size,
hidden_layers,
n_head,
padding_idx,
pretraining_tp=1,
slow_but_exact=False,
initializer_range=0.02,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0,
attention_dropout=0,
amp_enabled=False,
layer_norm_epsilon=1e-12,
cfg=None,
):
super().__init__()
self.cfg = cfg
self.transformer = BloomModel(
vocab_size=vocab_size,
hidden_size=hidden_size,
hidden_layers=hidden_layers,
n_head=n_head,
padding_idx=padding_idx,
pretraining_tp=pretraining_tp,
slow_but_exact=slow_but_exact,
initializer_range=initializer_range,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
amp_enabled=amp_enabled,
layer_norm_epsilon=layer_norm_epsilon,
cfg=cfg,
)
self.lm_head = LMLogits(vocab_size, bias=False)
@classmethod
def from_config(cls, cfg):
return {
"vocab_size": cfg.vocab_size,
"hidden_size": cfg.hidden_size,
"hidden_layers": cfg.hidden_layers,
"n_head": cfg.n_head,
"padding_idx": cfg.padding_idx,
"pretraining_tp": cfg.pretraining_tp,
"slow_but_exact": cfg.slow_but_exact,
"apply_residual_connection_post_layernorm": cfg.apply_residual_connection_post_layernorm, # noqa
"hidden_dropout": cfg.hidden_dropout,
"attention_dropout": cfg.attention_dropout,
"amp_enabled": cfg.amp_enabled,
"layer_norm_epsilon": cfg.layer_norm_epsilon,
"cfg": cfg,
}
def forward(
self,
input_ids=None,
attention_mask=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
):
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = transformer_outputs["last_hidden_state"]
lm_logits = self.lm_head(hidden_states, self.transformer.word_embeddings.weight)
return {
"logits": lm_logits,
"past_key_values": transformer_outputs["past_key_values"],
"hidden_states": transformer_outputs["last_hidden_state"],
# "attentions": transformer_outputs.attentions,
}
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs,
) -> dict:
# only last token for input_ids if past is not None
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
past_key_values = self._convert_to_bloom_cache(past_key_values)
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
def _reorder_cache(self, past, beam_idx):
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device)
for layer_past in past
for past_state in layer_past
}
reordered_past = tuple(
(
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
)
for layer_past in standardized_past
)
return self._convert_to_bloom_cache(reordered_past)
    def _convert_to_standard_cache(
        self,
        past_key_value,
        batch_size,
    ):
"""
Standardizes the format of the cache so as to match most implementations,
i.e. to tuple(tuple([batch_size, num_heads, ...]))
"""
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
num_heads = batch_size_times_num_heads // batch_size
return tuple(
(
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
    def _convert_to_bloom_cache(self, past_key_value):
"""
Converts the cache to the format expected by Bloom,
i.e. to tuple(tuple([batch_size * num_heads, ...]))
"""
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
batch_size_times_num_heads = batch_size * num_heads
return tuple(
(
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
)
for layer_past in past_key_value
)
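# ---------------------------------------------------------------------------
# Hedged layout sketch (editor's illustration): shows the reshapes performed
# by `_convert_to_bloom_cache` / `_convert_to_standard_cache` on one fake
# layer's key/value pair, with arbitrary sizes and plain local tensors.
if __name__ == "__main__":
    batch_size, num_heads, head_dim, seq_len = 2, 4, 8, 5
    key = flow.randn(batch_size, num_heads, head_dim, seq_len)    # standard key layout
    value = flow.randn(batch_size, num_heads, seq_len, head_dim)  # standard value layout

    bloom_key = key.view(batch_size * num_heads, head_dim, seq_len)
    bloom_value = value.view(batch_size * num_heads, seq_len, head_dim)
    print(bloom_key.shape, bloom_value.shape)  # fused batch*heads leading dim

    restored_key = bloom_key.view(batch_size, num_heads, head_dim, seq_len)
    assert (restored_key - key).abs().max().item() == 0  # round trip is exact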
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import oneflow as flow
from libai.utils import distributed as dist
def _make_causal_mask(input_ids_shape, past_key_values_length):
"""
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = flow.ones(
(target_length, target_length + past_key_values_length),
dtype=flow.bool,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
seq_ids = flow.arange(
target_length,
dtype=flow.long,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
if past_key_values_length > 0:
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(
batch_size, 1, target_length, target_length + past_key_values_length
)
return expanded_mask
def _expand_mask(mask, tgt_length):
"""
Expands attention_mask from `[batch_size, src_length]` to
`[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = ~(mask[:, None, None, :].to(flow.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
def build_alibi_tensor(attention_mask, num_heads, dtype):
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = flow.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=attention_mask.placement,
)
powers = flow.arange(
1,
1 + closest_power_of_2,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=attention_mask.placement,
)
slopes = flow.pow(base, powers)
if closest_power_of_2 != num_heads:
extra_base = flow.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=attention_mask.placement,
)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = flow.arange(
1,
1 + 2 * num_remaining_heads,
2,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=attention_mask.placement,
)
slopes = flow.cat([slopes, flow.pow(extra_base, extra_powers)], dim=0)
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
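# ---------------------------------------------------------------------------
# Hedged numeric sketch (editor's illustration): the slope schedule above,
# written out in plain Python for num_heads = 8. For a power-of-two head
# count the slopes are the geometric sequence 2^-1, 2^-2, ..., 2^-8, matching
# the ALiBi formulation; the `extra_powers` branch only applies to non-powers
# of two.
if __name__ == "__main__":
    num_heads = 8
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base ** p for p in range(1, 1 + closest_power_of_2)]
    print(slopes)  # [0.5, 0.25, 0.125, ..., 0.00390625]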
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import oneflow as flow
import oneflow.nn.functional as F
from oneflow import nn
from libai.layers import Linear
from projects.BLOOM.modeling.activation import BloomGelu
from projects.BLOOM.modeling.attention import dropout_add
class BloomMLP(nn.Module):
def __init__(
self,
hidden_size,
pretraining_tp,
slow_but_exact,
hidden_dropout,
init_method=None,
output_layer_init_method=None,
layer_idx=0,
):
super().__init__()
hidden_size = hidden_size
if output_layer_init_method is None:
output_layer_init_method = init_method
self.pretraining_tp = pretraining_tp
self.slow_but_exact = slow_but_exact
self.dense_h_to_4h = Linear(
hidden_size,
4 * hidden_size,
parallel="col",
init_method=init_method,
layer_idx=layer_idx,
)
self.gelu_impl = BloomGelu()
self.dense_4h_to_h = Linear(
4 * hidden_size,
hidden_size,
parallel="row",
init_method=output_layer_init_method,
layer_idx=layer_idx,
)
self.hidden_dropout = hidden_dropout
def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = flow.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
return output
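# ---------------------------------------------------------------------------
# Hedged equivalence sketch (editor's illustration): the `slow_but_exact`
# branch above splits the last activation dimension and the matching weight
# columns into `pretraining_tp` chunks and sums the partial matmuls. On plain
# local tensors this reproduces a single full matmul, shown here with
# `oneflow.nn.functional.linear` and arbitrary sizes.
if __name__ == "__main__":
    pretraining_tp, in_features, out_features = 4, 16, 8
    x = flow.randn(2, 3, in_features)
    weight = flow.randn(out_features, in_features)

    full = F.linear(x, weight)
    chunk = in_features // pretraining_tp
    summed = flow.zeros_like(full)
    for i in range(pretraining_tp):
        summed = summed + F.linear(
            x[:, :, i * chunk : (i + 1) * chunk],
            weight[:, i * chunk : (i + 1) * chunk],
        )
    print("max diff:", (full - summed).abs().max().item())  # ~ float rounding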
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from oneflow import nn
from libai.layers import LayerNorm
from libai.utils import distributed as dist
from projects.BLOOM.modeling.attention import BloomAttention
from projects.BLOOM.modeling.mlp import BloomMLP
class BloomBlock(nn.Module):
def __init__(
self,
hidden_size,
n_head,
layer_norm_epsilon,
hidden_dropout,
attention_dropout,
pretraining_tp,
slow_but_exact,
init_method,
output_layer_init_method,
apply_residual_connection_post_layernorm,
layer_idx=0,
):
super().__init__()
hidden_size = hidden_size
self.input_layernorm = LayerNorm(hidden_size, eps=layer_norm_epsilon, layer_idx=layer_idx)
self.num_heads = n_head
self.self_attention = BloomAttention(
hidden_size=hidden_size,
n_head=n_head,
hidden_dropout=hidden_dropout,
attention_dropout=attention_dropout,
pretraining_tp=pretraining_tp,
slow_but_exact=slow_but_exact,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_idx=layer_idx,
)
self.post_attention_layernorm = LayerNorm(
hidden_size, eps=layer_norm_epsilon, layer_idx=layer_idx
)
self.mlp = BloomMLP(
hidden_size,
pretraining_tp,
slow_but_exact,
hidden_dropout,
init_method,
output_layer_init_method,
layer_idx,
)
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.hidden_dropout = hidden_dropout
self.layer_idx = layer_idx
def forward(
self,
hidden_states,
alibi,
attention_mask,
layer_past=None,
head_mask=None,
use_cache: bool = False,
output_attentions: bool = False,
):
        # Change placement for pipeline parallelism
hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))
alibi = alibi.to_global(placement=dist.get_layer_placement(self.layer_idx))
# hidden_states shape: (batch_size, seq_length, hidden_size)
if attention_mask is not None:
attention_mask = attention_mask.to_global(
placement=dist.get_layer_placement(self.layer_idx)
)
layernorm_output = self.input_layernorm(hidden_states)
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attn_outputs = self.self_attention(
layernorm_output,
residual,
layer_past=layer_past,
attention_mask=attention_mask,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output, residual)
if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]
return outputs # hidden_states, present, attentions
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from libai.models.utils import ModelLoaderHuggerFace, ModelLoaderLiBai
class BlooMLoaderHuggerFace(ModelLoaderHuggerFace):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
"""NOTE: base_model_prefix_1 is BLOOM's prefix in Transformers.
base_model_prefix_2 is BLOOM's prefix in LiBai."""
self.base_model_prefix_1 = "transformer"
self.base_model_prefix_2 = "transformer"
def _convert_state_dict(self, flow_state_dict, cfg):
"""Convert state_dict's keys to match model.
Args:
flow_state_dict (OrderedDict): model state dict.
cfg (dict): model's default config dict in LiBai.
Returns:
OrderedDict: flow state dict.
"""
# The converted checkpoint.
oneflow_state_dict = flow_state_dict.copy()
old_keys = list(oneflow_state_dict.keys())
# prefix
has_prefix = any(s.startswith(self.base_model_prefix_1) for s in oneflow_state_dict)
prefix2 = "transformer." if has_prefix else ""
# Convert layers.
for key in old_keys:
oneflow_state_dict[prefix2 + key] = oneflow_state_dict.pop(key)
return oneflow_state_dict
def _load_config_from_json(self, config_file):
"""load config from `config.json`, and update default config.
Args:
config_file (str): Path of config file.
"""
with open(config_file, mode="r", encoding="utf-8") as f:
cfg_dict = json.load(f)
self._update_cfg("hidden_layers", cfg_dict["n_layer"])
self._update_cfg("hidden_size", cfg_dict["n_embed"])
self._update_cfg("n_head", cfg_dict["num_attention_heads"])
# update libai_cfg by config.json
for k, v in cfg_dict.items():
self._update_cfg(k, v)
# update libai_cfg by kwargs
for k, v in self.kwargs.items():
self._update_cfg(k, v)
self._update_cfg_log()
class BlooMLoaderLibai(ModelLoaderLiBai):
def __init__(self, model, libai_cfg, pretrained_model_path, **kwargs):
super().__init__(model, libai_cfg, pretrained_model_path, **kwargs)
self.base_model_prefix_2 = "transformer"
# CLIP
Contributor: Xingyu Liao <sherlockliao01@gmail.com>
> NOTE: We only support inference right now. Stay tuned for training part.
CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.
## Approach
![CLIP](CLIP.png)
## Usage
```python
import clip
import oneflow as flow
from PIL import Image
device = "cuda" if flow.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
image = (
preprocess(Image.open("CLIP.png"))
.unsqueeze(0)
.to_global(sbp=flow.sbp.split(0), placement=flow.placement("cuda", ranks=[0]))
)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to_global(
sbp=flow.sbp.split(0), placement=flow.placement("cuda", ranks=[0])
)
with flow.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
logits_per_image, logits_per_text = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
```
from .clip import load, tokenize
# --------------------------------------------------------
# Borrow code from:
# https://github.com/openai/CLIP/tree/main/clip/clip.py
# --------------------------------------------------------
import hashlib
import os
import urllib
import warnings
from typing import List, Union
import oneflow as flow
import torch
from flowvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from PIL import Image
from tqdm import tqdm
from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
from flowvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()
# noqa:
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # noqa: E501
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # noqa: E501
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", # noqa: E501
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # noqa: E501
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", # noqa: E501
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # noqa: E501
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # noqa: E501
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", # noqa: E501
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", # noqa: E501
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; "
"re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose(
[
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(
name: str,
device: Union[str, torch.device] = "cuda" if flow.cuda.is_available() else "cpu",
download_root: str = None,
):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a
model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : flow.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], flow.Tensor]
A flowvision transform that converts a PIL image into a tensor that
the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, "rb") as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location="cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
state_dict = torch.load(opened_file, map_location="cpu")
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.img_size)
def tokenize(
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[flow.IntTensor, flow.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
Parameters
----------
texts : Union[str, List[str]]
An input string or a list of input strings to tokenize
context_length : int
The context length to use; all CLIP models use 77 as the context length
truncate: bool
Whether to truncate the text in case its encoding is longer than the context length
Returns
-------
A two-dimensional tensor containing the resulting tokens,
shape = [number of input strings, context_length].
"""
if isinstance(texts, str):
texts = [texts]
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = flow.zeros(len(all_tokens), context_length, dtype=flow.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
if truncate:
tokens = tokens[:context_length]
tokens[-1] = eot_token
else:
raise RuntimeError(
f"Input {texts[i]} is too long for context length {context_length}"
)
result[i, : len(tokens)] = flow.tensor(tokens, dtype=flow.int)
return result
# --------------------------------------------------------
# Borrow code from:
# https://github.com/openai/CLIP/tree/main/clip/model.py
# --------------------------------------------------------
from collections import OrderedDict
from typing import Dict, Tuple, Union
import numpy as np
import oneflow as flow
import torch
from oneflow import nn
from libai.layers import MLP, Embedding, LayerNorm, Linear, MultiheadAttention, TransformerLayer
from libai.layers.activation import build_activation
from libai.layers.attention import AttnMaskType
from libai.models import VisionTransformer as ViT
from libai.utils import distributed as dist
from libai.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from .ops import multi_head_attention_forward
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed
# after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool,
# and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict(
[
("-1", nn.AvgPool2d(stride)),
(
"0",
nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False),
),
("1", nn.BatchNorm2d(planes * self.expansion)),
]
)
)
def forward(self, x: flow.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
flow.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
2, 0, 1
) # NCHW -> (HW)NC
x = flow.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=flow.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False,
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to flowvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1,
with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is
prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.to(dtype=self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
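# Illustrative RN50-style configuration (hypothetical call, mirroring what build_model
# below infers from the released RN50 checkpoint): layers=(3, 4, 6, 3), width=64,
# heads = 64 * 32 // 64 = 32, output_dim=1024, input_resolution=224.
# A (1, 3, 224, 224) input then yields a (1, 1024) feature vector from the attention pool.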
class MLPClip(MLP):
def __init__(
self,
hidden_size,
ffn_hidden_size,
output_dropout_prob=0,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
*,
layer_idx=0,
):
super().__init__(
hidden_size,
ffn_hidden_size,
output_dropout_prob,
init_method,
output_layer_init_method,
bias_gelu_fusion,
bias_dropout_fusion,
layer_idx=layer_idx,
)
if not bias_gelu_fusion:
self.activation_func = build_activation("quick_gelu")
class TransformerLayerClip(TransformerLayer):
def __init__(
self,
hidden_size,
ffn_hidden_size,
num_attention_heads,
is_decoder=False,
attention_dropout_prob=0,
output_dropout_prob=0,
drop_path_prob=0,
layernorm_epsilon=0.00001,
init_method=nn.init.xavier_normal_,
output_layer_init_method=None,
bias_gelu_fusion=False,
bias_dropout_fusion=False,
scale_mask_softmax_fusion=False,
apply_query_key_layer_scaling=False,
apply_residual_post_layernorm=False,
attn_mask_type=AttnMaskType.padding,
*,
layer_idx=0,
):
super().__init__(
hidden_size,
ffn_hidden_size,
num_attention_heads,
is_decoder,
attention_dropout_prob,
output_dropout_prob,
drop_path_prob,
layernorm_epsilon,
init_method,
output_layer_init_method,
bias_gelu_fusion,
bias_dropout_fusion,
scale_mask_softmax_fusion,
apply_query_key_layer_scaling,
apply_residual_post_layernorm,
attn_mask_type,
layer_idx=layer_idx,
)
self.mlp = MLPClip(
self.hidden_size,
self.ffn_hidden_size,
self.output_dropout_prob,
self.init_method,
output_layer_init_method=self.output_layer_init_method,
bias_gelu_fusion=self.bias_gelu_fusion,
bias_dropout_fusion=self.bias_dropout_fusion,
layer_idx=self.layer_idx,
)
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: flow.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.attn_mask = attn_mask
self.resblocks = nn.ModuleList(
[TransformerLayerClip(width, 4 * width, heads, layer_idx=i) for i in range(layers)]
)
def forward(self, x: flow.Tensor):
for layer in self.resblocks:
x = layer(x, self.attn_mask)
return x
class VisionTransformer(ViT):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=192,
depth=12,
num_heads=3,
mlp_ratio=4,
drop_rate=0,
attn_drop_rate=0,
drop_path_rate=0,
num_classes=1000,
loss_func=None,
):
super().__init__(
img_size,
patch_size,
in_chans,
embed_dim,
depth,
num_heads,
mlp_ratio,
drop_rate,
attn_drop_rate,
drop_path_rate,
num_classes,
loss_func,
)
self.ln_pre = LayerNorm(embed_dim, layer_idx=0)
self.head = Linear(embed_dim, num_classes, bias=False, layer_idx=-1)
def forward_features(self, x):
# patch embedding
x = self.patch_embed(x)
cls_token = self.cls_token.expand(
x.shape[0], -1, -1
) # stole cls_tokens impl from Phil Wang, thanks
cls_token = cls_token.to_global(sbp=x.sbp, placement=cls_token.placement)
x = flow.cat((cls_token, x), dim=1)
# position embedding
pos_embed = self.pos_embed.expand(x.shape[0], -1, -1)
pos_embed = pos_embed.to_global(sbp=x.sbp, placement=pos_embed.placement)
x = self.pos_drop(x + pos_embed)
# layernorm_pre
x = self.ln_pre(x)
# transformer block
x = self.blocks(x)
return x
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width,
).to_global(sbp=flow.sbp.broadcast, placement=dist.get_layer_placement(0))
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
img_size=image_resolution,
patch_size=vision_patch_size,
embed_dim=vision_width,
depth=vision_layers,
num_heads=vision_heads,
num_classes=embed_dim,
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
)
self.vocab_size = vocab_size
self.token_embedding = Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
flow.empty(
self.context_length,
transformer_width,
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=dist.get_layer_placement(0),
)
)
self.ln_final = LayerNorm((transformer_width,), layer_idx=-1)
self.text_projection = nn.Parameter(
flow.empty(
transformer_width,
embed_dim,
sbp=flow.sbp.broadcast,
placement=dist.get_layer_placement(0),
)
)
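# learnable softmax temperature; logit_scale.exp() starts at 1 / 0.07 as in the CLIP paper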
self.logit_scale = nn.Parameter(
flow.ones([], sbp=flow.sbp.broadcast, placement=dist.get_layer_placement(0))
* np.log(1 / 0.07)
)
self.initialize_parameters()
def initialize_parameters(self):
if hasattr(self.visual, "patch_embed"):
nn.init.zeros_(self.visual.patch_embed.proj.bias)
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [
self.visual.layer1,
self.visual.layer2,
self.visual.layer3,
self.visual.layer4,
]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.self_attention.query_key_value.weight, std=attn_std)
nn.init.normal_(block.self_attention.dense.weight, std=proj_std)
nn.init.normal_(block.mlp.dense_h_to_4h.weight, std=fc_std)
nn.init.normal_(block.mlp.dense_4h_to_h.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create the causal attention mask for the text transformer
# (a 0/1 mask produced by flow.tril: 1 = may attend, 0 = masked)
mask = flow.ones(
self.context_length,
self.context_length,
sbp=flow.sbp.broadcast,
placement=dist.get_layer_placement(0),
)
mask = flow.tril(mask)  # keep the lower triangle so each position only attends to itself and earlier positions
return mask
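# For context_length = 4 the mask is
#   [[1, 0, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 0],
#    [1, 1, 1, 1]]
# so text token i may only attend to positions <= i; zero entries are treated as
# masked positions by the attention layers.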
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image)["prediction_scores"]
def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
# x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
# x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = (
x[flow.arange(x.shape[0], sbp=x.sbp, placement=x.placement), text.argmax(dim=-1)]
@ self.text_projection
)
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.to(dtype=flow.float16)
if l.bias is not None:
l.bias.data = l.bias.data.to(dtype=flow.float16)
if isinstance(l, MultiheadAttention):
for attr in ["query_key_value", "dense"]:
layer = getattr(l, attr)
weight = getattr(layer, "weight")
if weight is not None:
weight.data = weight.data.to(dtype=flow.float16)
bias = getattr(layer, "bias")
if bias is not None:
bias.data = bias.data.to(dtype=flow.float16)
if hasattr(l, "text_projection"):
attr = getattr(l, "text_projection")
if attr is not None:
attr.data = attr.data.to(dtype=flow.float16)
if hasattr(l, "proj"):
attr = getattr(l, "proj")
if attr is not None:
attr.weight.data = attr.weight.data.to(dtype=flow.float16)
model.apply(_convert_weights_to_fp16)
def load_tensor(tensor_lhs: flow.Tensor, tensor_rhs: torch.Tensor):
tensor_rhs = flow.Tensor(
tensor_rhs.cpu().numpy(),
sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
placement=flow.env.all_device_placement("cuda"),
).to_global(sbp=tensor_lhs.sbp, placement=tensor_lhs.placement)
tensor_lhs.data.copy_(tensor_rhs.data)
def load_weights(model: nn.Module, state_dict: Dict):
model_state_dict = model.state_dict()
incorrect_shapes = []
for k in list(state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(state_dict[k].shape)
if shape_model != shape_checkpoint:
incorrect_shapes.append((k, shape_checkpoint, shape_model))
state_dict.pop(k)
unexpected_keys = []
for key, value in state_dict.items():
if key not in model_state_dict:
unexpected_keys.append(key)
# skip this key
continue
model_state_dict.pop(key)
load_tensor(model.state_dict()[key], value)
missing_keys = list(model_state_dict.keys())
for k, shape_checkpoint, shape_model in incorrect_shapes:
print(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if missing_keys:
print(get_missing_parameters_message(missing_keys))
if unexpected_keys:
print(get_unexpected_parameters_message(unexpected_keys))
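# The two helpers below assume head_dim == 64 (true for every released CLIP checkpoint)
# and reorder OpenAI's fused in_proj parameters, laid out as [q; k; v], into the
# head-major [h0_q, h0_k, h0_v, h1_q, ...] layout expected by LiBai's fused
# query_key_value projection.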
def convert_qkv_weight(qkv_weight, num_heads):
qkv_weight = qkv_weight.view([3, num_heads, 64, num_heads * 64])
qkv_weight = (
qkv_weight.permute(1, 0, 2, 3).contiguous().view(3 * num_heads * 64, num_heads * 64)
)
return qkv_weight
def convert_qkv_bias(qkv_bias, num_heads):
qkv_bias = qkv_bias.view(3, num_heads, 64)
qkv_bias = qkv_bias.permute(1, 0, 2).contiguous().view(-1)
return qkv_bias
def change_vit_state_dict(state_dict, visual_num_heads, text_num_heads):
new_state_dict = {}
for key, value in state_dict.items():
# change prefix
if "visual.transformer.resblocks" in key:
key = key.replace("visual.transformer.resblocks", "visual.blocks")
# change "ln_1" to "input_layernorm"
if "ln_1" in key:
key = key.replace("ln_1", "input_layernorm")
# change "ln_2" to "post_attention_layernorm"
if "ln_2" in key:
key = key.replace("ln_2", "post_attention_layernorm")
# change "attn.out_proj" to "attention.dense"
if "attn.out_proj" in key:
key = key.replace("attn.out_proj", "attention.dense")
# change "attn" to "attention.query_key_value"
if "attn.in_proj_weight" in key:
key = key.replace("attn.in_proj_weight", "attention.query_key_value.weight")
if "visual" not in key:
value = convert_qkv_weight(value, text_num_heads)
else:
value = convert_qkv_weight(value, visual_num_heads)
if "attn.in_proj_bias" in key:
key = key.replace("attn.in_proj_bias", "attention.query_key_value.bias")
if "visual" not in key:
value = convert_qkv_bias(value, text_num_heads)
else:
value = convert_qkv_bias(value, visual_num_heads)
# change "mlp.c_fc" to "mlp.dense_h_to_4h"
if "mlp.c_fc" in key:
key = key.replace("mlp.c_fc", "mlp.dense_h_to_4h")
# change "mlp.c_proj" to "mlp.dense_4h_to_h"
if "mlp.c_proj" in key:
key = key.replace("mlp.c_proj", "mlp.dense_4h_to_h")
# change "class_embedding" to "cls_token"
if "class_embedding" in key:
key = key.replace("class_embedding", "cls_token")
value = value.unsqueeze(0).unsqueeze(0)
# change "pos_embed" to "positional_embedding"
if "visual.positional_embedding" == key:
key = "visual.pos_embed"
value = value.unsqueeze(0)
# change patch_embedding
if key == "visual.conv1.weight":
key = "visual.patch_embed.proj.weight"
# change "ln_post"
if "ln_post" in key:
key = key.replace("ln_post", "norm")
# change "proj"
if "visual.proj" == key:
key = "visual.head.weight"
value = value.transpose(0, 1)
# added by huangwei
key = key.replace("attention.query_key_value", "self_attention.query_key_value").replace(
"attention.dense", "self_attention.dense"
)
new_state_dict[key] = value
return new_state_dict
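# Resulting key mapping, for illustration:
#   "visual.transformer.resblocks.0.ln_1.weight"
#       -> "visual.blocks.0.input_layernorm.weight"
#   "visual.transformer.resblocks.0.attn.in_proj_weight"
#       -> "visual.blocks.0.self_attention.query_key_value.weight"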
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len(
[
k
for k in state_dict.keys()
if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
]
)
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}")))
for b in [1, 2, 3, 4]
]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round(
(state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))
)
if vit:
state_dict = change_vit_state_dict(state_dict, vision_width // 64, transformer_heads)
model = CLIP(
embed_dim,
image_resolution,
vision_layers,
vision_width,
vision_patch_size,
context_length,
vocab_size,
transformer_width,
transformer_heads,
transformer_layers,
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
# convert_weights(model)
load_weights(model, state_dict)
return model.eval()
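# Example (ViT-B/32 checkpoint): the shapes above resolve to vision_width=768,
# vision_layers=12, vision_patch_size=32, image_resolution=224, embed_dim=512,
# context_length=77, vocab_size=49408, transformer_width=512, transformer_heads=8,
# transformer_layers=12.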
# --------------------------------------------------------
# Reference:
# https://github.com/pytorch/pytorch/blob/1.7/torch/nn/functional.py#L4041
# --------------------------------------------------------
import warnings
from typing import Optional, Tuple
import oneflow as flow
import oneflow.nn.functional as F
from oneflow import Tensor
def multi_head_attention_forward(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Tensor,
in_proj_bias: Tensor,
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Tensor,
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if is ``True``.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. This is a binary mask. When the value is True,
the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions.
A 2D mask will be broadcasted for all the batches while a 3D mask allows
to specify a different mask for the entries of each batch.
use_separate_proj_weight: the function accept the proj. weights for query, key,
and value in different forms. If false, in_proj_weight will be used, which is
a combination of q_proj_weight, k_proj_weight, v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source
sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while
the zero positions will be unchanged. If a BoolTensor is provided, the positions
with the value of ``True`` will be ignored while the position with the value
of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length,
S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target
sequence length, S is the source sequence length. attn_mask ensures that position
i is only allowed to attend the unmasked positions.
If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions
with ``True`` are not allowed to attend while ``False`` values will be unchanged.
If a FloatTensor is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
L is the target sequence length, S is the source sequence length.
"""
tgt_len, bsz, embed_dim = query.size()
assert embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
if not use_separate_proj_weight:
if flow.equal(query, key) and flow.equal(key, value):
# self-attention
q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
elif flow.equal(key, value):
# encoder-decoder attention
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
if key is None:
assert value is None
k = None
v = None
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
else:
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = 0
_end = embed_dim
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
q = F.linear(query, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
_end = embed_dim * 2
_w = in_proj_weight[_start:_end, :]
if _b is not None:
_b = _b[_start:_end]
k = F.linear(key, _w, _b)
# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim * 2
_end = None
_w = in_proj_weight[_start:, :]
if _b is not None:
_b = _b[_start:]
v = F.linear(value, _w, _b)
else:
len1, len2 = q_proj_weight.size()
assert len1 == embed_dim and len2 == query.size(-1)
len1, len2 = k_proj_weight.size()
assert len1 == embed_dim and len2 == key.size(-1)
len1, len2 = v_proj_weight.size()
assert len1 == embed_dim and len2 == value.size(-1)
if in_proj_bias is not None:
q = F.linear(query, q_proj_weight, in_proj_bias[0:embed_dim])
k = F.linear(key, k_proj_weight, in_proj_bias[embed_dim : (embed_dim * 2)])
v = F.linear(value, v_proj_weight, in_proj_bias[(embed_dim * 2) :])
else:
q = F.linear(query, q_proj_weight, in_proj_bias)
k = F.linear(key, k_proj_weight, in_proj_bias)
v = F.linear(value, v_proj_weight, in_proj_bias)
q = q * scaling
if attn_mask is not None:
assert (
attn_mask.dtype == flow.float32
or attn_mask.dtype == flow.float64
or attn_mask.dtype == flow.float16
or attn_mask.dtype == flow.uint8
or attn_mask.dtype == flow.bool
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
attn_mask.dtype
)
if attn_mask.dtype == flow.uint8:
warnings.warn(
"Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. "
"Use bool tensor instead."
)
attn_mask = attn_mask.to(flow.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
else:
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == flow.uint8:
warnings.warn(
"Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. "
"Use bool tensor instead."
)
key_padding_mask = key_padding_mask.to(flow.bool)
assert bias_k is None, "Only support bias_k is None"
assert bias_v is None, "Only support bias_v is None"
q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
attn_output_weights = flow.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == flow.bool:
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float("-inf"),
)
attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
attn_output_weights = F.softmax(attn_output_weights, dim=-1)
attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
attn_output = flow.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
return attn_output, attn_output_weights.sum(dim=1) / num_heads
else:
return attn_output, None
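# Minimal usage sketch (hypothetical shapes), mirroring the separate-projection call made
# by AttentionPool2d in model.py:
#   L, N, E, H = 50, 2, 256, 8
#   x = flow.randn(L, N, E)
#   out, _ = multi_head_attention_forward(
#       query=x, key=x, value=x, embed_dim_to_check=E, num_heads=H,
#       in_proj_weight=None, in_proj_bias=flow.randn(3 * E),
#       bias_k=None, bias_v=None, dropout_p=0.0,
#       out_proj_weight=flow.randn(E, E), out_proj_bias=flow.randn(E),
#       use_separate_proj_weight=True,
#       q_proj_weight=flow.randn(E, E), k_proj_weight=flow.randn(E, E),
#       v_proj_weight=flow.randn(E, E),
#       need_weights=False,
#   )
#   # out has shape (L, N, E); the second return value is None when need_weights=False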
# --------------------------------------------------------
# Borrow code from:
# https://github.com/openai/CLIP/tree/main/clip/simple_tokenizer.py
# --------------------------------------------------------
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
from libai.utils.file_utils import download_file
@lru_cache()
def default_bpe():
default_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
)
if not os.path.exists(default_path):
download_file(
default_path,
"https://oneflow-static.oss-cn-beijing.aliyuncs.com/libai/clip/bpe_simple_vocab_16e6.txt.gz", # noqa: E501
)
return default_path
@lru_cache()
def bytes_to_unicode():
"""
Returns a mapping between utf-8 bytes and corresponding unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you
want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing
around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
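# Example: bytes_to_unicode()[ord("a")] == "a", while bytes outside the printable ranges
# are remapped to code points >= 256, e.g. bytes_to_unicode()[0] == chr(256).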
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
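# Example: get_pairs(("h", "e", "l", "l", "o</w>"))
#   == {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o</w>")}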
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
merges = merges[1 : 49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + "</w>" for v in vocab]
for merge in merges:
vocab.append("".join(merge))
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {"<|startoftext|>": "<|startoftext|>", "<|endoftext|>": "<|endoftext|>"}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", # noqa
re.IGNORECASE,
)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + "</w>",)
pairs = get_pairs(word)
if not pairs:
return token + "</w>"
while True:
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except: # noqa
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = " ".join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
return bpe_tokens
def decode(self, tokens):
text = "".join([self.decoder[token] for token in tokens])
text = (
bytearray([self.byte_decoder[c] for c in text])
.decode("utf-8", errors="replace")
.replace("</w>", " ")
)
return text
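# Round-trip sketch (assumes the default BPE vocab has been fetched by default_bpe()):
#   tok = SimpleTokenizer()
#   ids = tok.encode("Hello, world!")   # lower-cased, whitespace-cleaned BPE token ids
#   tok.decode(ids)                     # -> "hello , world ! " (each "</w>" becomes a space)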
ftfy
regex
tqdm
oneflow
flowvision
torch
torchvision