Commit 2384a2ca authored by chenxj

initial commit
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Tuple
from aitemplate.compiler import ops
from aitemplate.frontend import nn, Tensor
from aitemplate.testing import detect_target
# pylint: disable=W0102
USE_CUDA = detect_target().name() == "cuda"
class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, layer_norm_eps):
"""dense + add is included in nn.MultiheadAttention.
This class now only contains LayerNorm.
"""
super().__init__()
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: Tensor) -> Tensor:
if not USE_CUDA:
hidden_states = (
hidden_states
if hidden_states._rank() == 2
else ops.reshape()(hidden_states, [-1, hidden_states._size(-1)])
)
# [B, S, H] on cuda, [B * S, H] on rocm
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertAttention(nn.Module):
def __init__(
self,
batch_size,
seq_len,
hidden_size,
num_attention_heads,
layer_norm_eps,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
):
super().__init__()
self.self = nn.MultiheadAttention(
dim=hidden_size,
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_attention_heads,
qkv_bias=True,
attn_drop=attention_probs_dropout_prob,
proj_drop=hidden_dropout_prob,
has_residual=True,
)
self.output = BertSelfOutput(hidden_size, layer_norm_eps)
def forward(
self,
hidden_states: Tensor,
) -> Tuple[Tensor]:
self_output = self.self(hidden_states, hidden_states)
attention_output = self.output(self_output)
outputs = (attention_output,)
return outputs
# FFN block
class BertIntermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act):
super().__init__()
# dense + activation
self.dense = nn.Linear(
hidden_size, intermediate_size, specialization=hidden_act
)
def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(
self, hidden_size, intermediate_size, layer_norm_eps, hidden_dropout_prob
):
super().__init__()
assert hidden_dropout_prob == 0.0
# dense + add
self.dense = nn.Linear(intermediate_size, hidden_size, specialization="add")
self.dropout = nn.Dropout(hidden_dropout_prob)
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
hidden_states = self.dense(hidden_states, input_tensor)
# hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLayer(nn.Module):
def __init__(
self,
hidden_size,
batch_size,
seq_len,
num_attention_heads,
intermediate_size,
hidden_act,
layer_norm_eps,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
super().__init__()
self.attention = BertAttention(
batch_size=batch_size,
seq_len=seq_len,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
)
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
self.output = BertOutput(
hidden_size, intermediate_size, layer_norm_eps, hidden_dropout_prob
)
def feed_forward(self, attention_output):
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
def forward(
self,
hidden_states: Tensor,
):
# [B, S, H]
shape = hidden_states.shape()
# [B, S, H] on cuda, [B * S, H] on rocm
self_attention_outputs = self.attention(hidden_states)
layer_output = self.feed_forward(self_attention_outputs[0])
# [B * S, H] to [B, S, H] on rocm
layer_output = (
layer_output
if layer_output._rank() == 3
else ops.reshape()(layer_output, shape)
)
return (layer_output,)
class BertEncoder(nn.Module):
def __init__(
self,
num_hidden_layers,
hidden_size,
batch_size,
seq_len,
num_attention_heads,
intermediate_size,
hidden_act,
layer_norm_eps,
attention_probs_dropout_prob,
hidden_dropout_prob,
):
super().__init__()
self.layer = nn.ModuleList(
[
BertLayer(
batch_size=batch_size,
seq_len=seq_len,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
)
for _ in range(num_hidden_layers)
]
)
def forward(
self,
hidden_states: Tensor,
):
for layer_module in self.layer:
layer_outputs = layer_module(hidden_states)
hidden_states = layer_outputs[0]
return layer_outputs
class BertModel(nn.Module):
def __init__(
self,
batch_size,
seq_len,
vocab_size,
max_position_embeddings,
type_vocab_size,
num_hidden_layers,
hidden_size,
num_attention_heads,
intermediate_size,
hidden_act,
layer_norm_eps,
attention_probs_dropout_prob,
hidden_dropout_prob,
add_pooling_layer=False,
):
super().__init__()
assert not add_pooling_layer
self.embeddings = nn.BertEmbeddings(
hidden_size=hidden_size,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
)
self.encoder = BertEncoder(
batch_size=batch_size,
seq_len=seq_len,
num_hidden_layers=num_hidden_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
)
def forward(
self,
input_ids: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
):
embedding_output = self.embeddings(
input_ids=input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
)
encoder_outputs = self.encoder(
embedding_output,
)
return encoder_outputs
class BertModelEncodersOnly(nn.Module):
def __init__(
self,
batch_size,
seq_len,
num_hidden_layers,
hidden_size,
num_attention_heads,
intermediate_size,
hidden_act,
layer_norm_eps,
attention_probs_dropout_prob,
hidden_dropout_prob,
add_pooling_layer=False,
):
super().__init__()
assert not add_pooling_layer
self.encoder = BertEncoder(
batch_size=batch_size,
seq_len=seq_len,
num_hidden_layers=num_hidden_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
)
def forward(
self,
encoder_input: Tensor,
):
encoder_outputs = self.encoder(encoder_input)
return encoder_outputs
class BertBaseUncased(nn.Module):
"""Bert base uncased with no classification head."""
def __init__(
self,
batch_size,
seq_len,
vocab_size=30522,
max_position_embeddings=512,
type_vocab_size=2,
num_hidden_layers=12,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
layer_norm_eps=1e-12,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
):
super().__init__()
self.bert = BertModel(
batch_size=batch_size,
seq_len=seq_len,
vocab_size=vocab_size,
max_position_embeddings=max_position_embeddings,
type_vocab_size=type_vocab_size,
num_hidden_layers=num_hidden_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
add_pooling_layer=False,
)
def forward(
self,
input_ids: Tensor,
token_type_ids: Tensor,
position_ids: Tensor,
    ) -> Tuple[Tensor]:
outputs = self.bert(
input_ids=input_ids,
token_type_ids=token_type_ids,
position_ids=position_ids,
)
return outputs
class BertBaseEncodersOnly(nn.Module):
"""Bert base uncased with no classification head and no embeddings."""
def __init__(
self,
batch_size,
seq_len,
num_hidden_layers=12,
hidden_size=768,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
layer_norm_eps=1e-12,
attention_probs_dropout_prob=0.0,
hidden_dropout_prob=0.0,
):
super().__init__()
self.bert = BertModelEncodersOnly(
batch_size=batch_size,
seq_len=seq_len,
num_hidden_layers=num_hidden_layers,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
hidden_act=hidden_act,
layer_norm_eps=layer_norm_eps,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
add_pooling_layer=False,
)
def forward(
self,
encoder_input: Tensor,
    ) -> Tuple[Tensor]:
outputs = self.bert(encoder_input)
return outputs
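

# Usage sketch (not part of the original commit): how the encoder-only model
# above could be compiled with AITemplate, mirroring the compile_vit() flow in
# benchmark_ait.py further down. The shapes, work directory, and module name
# below are illustrative assumptions, not values taken from this commit.
if __name__ == "__main__":
    from aitemplate.compiler import compile_model

    batch_size, seq_len, hidden = 1, 128, 768  # assumed example shapes
    ait_model = BertBaseEncodersOnly(batch_size=batch_size, seq_len=seq_len)
    ait_model.name_parameter_tensor()
    encoder_input = Tensor(
        [batch_size, seq_len, hidden], name="input", is_input=True
    )
    # forward() returns a 1-tuple; mark its single tensor as the graph output
    y = ait_model(encoder_input)[0]
    y._attrs["is_output"] = True
    y._attrs["name"] = "output_0"
    compile_model(y, detect_target(), "./tmp", "bert_encoders_only_sketch")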
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from transformers import AutoModelForMaskedLM, BertForMaskedLM
class BertBaseUncased:
def __init__(self, model_path="bert-base-uncased", pretrained=True):
        if not pretrained:
            # Load the pretrained checkpoint only to reuse its config, then
            # build a randomly initialized model from that config.
            reference = AutoModelForMaskedLM.from_pretrained(model_path)
            self._model = BertForMaskedLM(reference.config).cuda().half()
        else:
            self._model = AutoModelForMaskedLM.from_pretrained(model_path).cuda().half()
self._vocab_size = 30522
def forward(self, *args, **kwargs):
# runs the full model with classification head
outputs = self._model(*args, **kwargs)
return outputs.logits
def generate_inputs(self, batch_size, seq_len):
dtype = torch.long
input_ids = torch.randint(
0, self._vocab_size, (batch_size, seq_len), dtype=dtype
).cuda()
token_type_ids = torch.zeros(input_ids.size(), dtype=dtype).cuda()
position_ids = (
torch.arange(seq_len, dtype=dtype)
.reshape((1, -1))
.expand(batch_size, -1)
.contiguous()
.cuda()
)
return (input_ids, token_type_ids, position_ids)
def get_parameters(self):
return dict(self._model.named_parameters())
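

# Usage sketch (not part of the original commit): running the eager PyTorch
# reference model defined above. The batch size and sequence length are
# illustrative assumptions.
if __name__ == "__main__":
    pt_bert = BertBaseUncased(pretrained=True)
    input_ids, token_type_ids, position_ids = pt_bert.generate_inputs(
        batch_size=1, seq_len=128
    )
    with torch.inference_mode():
        logits = pt_bert.forward(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
    print(logits.shape)  # expected: (1, 128, vocab_size)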
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""benchmark for vit"""
import os
import click
import numpy as np
import torch
from aitemplate.compiler import compile_model, Model
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from modeling.vision_transformer import VisionTransformer
from weight_utils import export_to_torch_tensor
# flake8: noqa
def mark_output(y):
    if not isinstance(y, tuple):
        y = (y,)
    for i, out in enumerate(y):
        out._attrs["is_output"] = True
        out._attrs["name"] = f"output_{i}"
        y_shape = [d._attrs["values"][0] for d in out._attrs["shape"]]
        print(f"output_{i} shape: {y_shape}")
USE_CUDA = detect_target().name() == "cuda"
def compile_vit(
model_name,
batch_size,
class_token=False,
global_pool="avg",
use_fp16_acc=True,
):
img_size = 224
patch_size = 16
embed_dim = 768
num_heads = 12
depth = 12
if model_name == "vit_base_patch16_224":
img_size = 224
patch_size = 16
embed_dim = 768
num_heads = 12
depth = 12
elif model_name == "vit_large_patch16_384":
img_size = 384
patch_size = 16
embed_dim = 1024
num_heads = 16
depth = 24
seqlen = (img_size // patch_size) ** 2 + (1 if class_token else 0)
ait_model = VisionTransformer(
batch_size=batch_size,
img_size=img_size,
class_token=class_token,
global_pool=global_pool,
num_heads=num_heads,
embed_dim=embed_dim,
patch_size=patch_size,
depth=depth,
act_layer="GELU",
)
ait_model.name_parameter_tensor()
inputs_ait = Tensor(
[batch_size, img_size, img_size, 3], name="input0", is_input=True
)
Y = ait_model(inputs_ait)
mark_output(Y)
target = detect_target(use_fp16_acc=use_fp16_acc)
exe_module = compile_model(
Y, target, "./tmp", "vision_transformer_bs%d_seq%d" % (batch_size, seqlen), profile_devs=[0,1,2,3]
)
return exe_module
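
# Example (illustrative, not from the original commit): building the base model
# for batch size 1 before benchmarking. Compilation artifacts land under
# ./tmp/<module_name>/test.so, which benchmark() loads when `mod` is not passed in.
#   exe = compile_vit("vit_base_patch16_224", batch_size=1, class_token=True)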
def benchmark(model_name, batch_size, params_ait, mod=None, graph_mode=True):
# load mod
if model_name == "vit_base_patch16_224":
img_size = 224
patch_size = 16
embed_dim = 768
num_heads = 12
depth = 12
elif model_name == "vit_large_patch16_384":
img_size = 384
patch_size = 16
embed_dim = 1024
num_heads = 16
depth = 24
else:
raise NotImplementedError
seqlen = (img_size // patch_size) ** 2 + 1
if mod is None:
model_dir = f"vision_transformer_bs{batch_size}_seq{seqlen}"
mod = Model(os.path.join("./tmp", model_dir, "test.so"))
# prepare params
params_ait["cls_token_mask"] = torch.zeros((batch_size, 1, embed_dim)).cuda().half()
params_ait["fc_norm_weight"] = params_ait["norm_weight"]
params_ait["fc_norm_bias"] = params_ait["norm_bias"]
if detect_target().name() == "cuda":
ait_key = "attn_cu_length"
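        # attn_cu_length holds cumulative sequence offsets ([0, seqlen, 2*seqlen, ...])
        # consumed by the CUDA attention kernel to locate each batch element.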
for i in range(depth):
prefix = "blocks_%d" % (i)
cu_len = np.cumsum([0] + [seqlen] * batch_size).astype("int32")
params_ait[f"{prefix}_{ait_key}"] = torch.from_numpy(cu_len).cuda()
# set weights
mod.set_many_constants_with_tensors(params_ait)
mod.fold_constants(sync=True)
# prepare input/output tensor
inputs = [torch.randn([batch_size, img_size, img_size, 3]).cuda().half()]
ys = []
num_outputs = len(mod.get_output_name_to_index_map())
for i in range(num_outputs):
shape = mod.get_output_maximum_shape(i)
ys.append(torch.empty(shape).cuda().half())
# warm up
t, _, __ = mod.benchmark_with_tensors(
inputs,
ys,
count=100,
repeat=4,
graph_mode=graph_mode,
)
# benchmark
t, _, __ = mod.benchmark_with_tensors(
inputs,
ys,
count=100,
repeat=4,
graph_mode=graph_mode,
)
print(f"batch_size: {batch_size}, latency: {t}")
dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1")
dev_flag = dev_flag.replace(",", "_")
with open(f"{model_name}_ait_benchmark_dev_{dev_flag}.txt", "a") as f:
f.write(f"batch_size: {batch_size}, latency: {t}\n")
@click.command()
@click.option("--model-name", type=str, default="vit_base_patch16_224")
@click.option(
"--use-fp16-acc",
type=bool,
default=True,
help="Whether to use FP16 for accumulation (similar to TensorRT)",
)
@click.option("--use-graph", type=bool, default=True, help="Whether to use CUDA graph")
@click.option("--batch-size", type=int, default=0, help="Batch size")
def main(
model_name="vit_base_patch16_224", use_fp16_acc=True, use_graph=True, batch_size=0
):
if detect_target().name() == "rocm":
use_graph = False
if model_name == "vit_base_patch16_224":
pretrained_path = "./vit_base_patch16_224.augreg2_in21k_ft_in1k/pytorch_model.bin"
elif model_name == "vit_large_patch16_384":
pretrained_path = "./vit_large_patch16_384.augreg_in21k_ft_in1k/pytorch_model.bin"
else:
raise NotImplementedError
params_ait = export_to_torch_tensor(model_name, model_path=pretrained_path, pretrained=True)
if batch_size < 1:
for bs in (1, 2, 4, 8, 16, 32, 64, 128, 256):
compile_vit(model_name, bs, class_token=True, use_fp16_acc=use_fp16_acc)
benchmark(model_name, bs, params_ait, graph_mode=use_graph)
else:
benchmark(model_name, batch_size, params_ait, graph_mode=use_graph)
if __name__ == "__main__":
main()
#!/bin/bash
# profile on all GCDs
HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 benchmark_ait.py

# 1 GCD
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1"

# 2 GCDs in parallel
HIP_VISIBLE_DEVICES=0 python3 benchmark_ait.py --batch-size "$1" &
HIP_VISIBLE_DEVICES=1 python3 benchmark_ait.py --batch-size "$1" &
wait
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import click
import torch
from aitemplate.testing.benchmark_pt import benchmark_torch_function
from timm.models.vision_transformer import vit_base_patch16_224, vit_large_patch16_384
def create_vit(model_name):
    if model_name == "vit_base_patch16_224":
        model_path = "./vit_base_patch16_224.augreg2_in21k_ft_in1k/pytorch_model.bin"
        model = vit_base_patch16_224(pretrained=True, pretrained_path=model_path).cuda().half()
    elif model_name == "vit_large_patch16_384":
        model_path = "./vit_large_patch16_384.augreg_in21k_ft_in1k/pytorch_model.bin"
        model = vit_large_patch16_384(pretrained=True, pretrained_path=model_path).cuda().half()
    else:
        raise NotImplementedError(model_name)
    return model
def benchmark(model_name, batch_size, img_size, model):
if model_name == "vit_base_patch16_224":
img_size = 224
elif model_name == "vit_large_patch16_384":
img_size = 384
with torch.inference_mode():
input_shape = (batch_size, 3, img_size, img_size)
input_data = torch.randn(input_shape).cuda().half()
# warm up
benchmark_torch_function(100, model, input_data)
# benchmark
t = benchmark_torch_function(100, model, input_data)
print("batch_size: {}, time: {}".format(batch_size, t))
dev_flag = os.environ.get("HIP_VISIBLE_DEVICES", "-1")
dev_flag = dev_flag.replace(",", "_")
with open(f"{model_name}_pt_benchmark_dev_{dev_flag}.txt", "a") as f:
f.write("batch_size: {}, latency: {}\n".format(batch_size, t))
@click.command()
@click.option("--model-name", type=str, default="vit_base_patch16_224")
@click.option("--batch-size", default=0, type=int)
def main(model_name, batch_size):
img_size = 224
if model_name == "vit_base_patch16_224":
img_size = 224
elif model_name == "vit_large_patch16_384":
img_size = 384
else:
raise NotImplementedError
model = create_vit(model_name)
if batch_size == 0:
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256]:
benchmark(model_name, batch_size, img_size, model)
else:
benchmark(model_name, batch_size, img_size, model)
if __name__ == "__main__":
main()
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
from aitemplate.compiler import ops
from aitemplate.frontend import nn
from aitemplate.testing import detect_target
# pylint: disable=W0102
USE_CUDA = detect_target().name() == "cuda"
def get_shape(x):
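    """Return the static shape of an AIT Tensor as a list of Python ints."""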
shape = [it.value() for it in x._attrs["shape"]]
return shape
class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer="GELU",
drop=0,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(
in_features,
hidden_features,
specialization="fast_gelu" if act_layer == "GELU" else "relu",
)
self.fc2 = nn.Linear(hidden_features, out_features, specialization="add")
def forward(self, x, res):
shape = get_shape(x)
x = self.fc1(x)
x = self.fc2(x, res)
return ops.reshape()(x, shape)
class Block(nn.Module):
def __init__(
self,
dim,
batch_size,
seq_len,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
drop=0.0,
attn_drop=0.0,
init_values=None,
drop_path=0.0,
act_layer="GELU",
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = nn.MultiheadAttention(
dim,
batch_size,
seq_len,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = nn.DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = nn.Identity()
self.drop_path2 = nn.DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
def forward(self, x):
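        # Both sub-blocks take (normalized input, residual): the residual add is
        # fused into nn.MultiheadAttention and into the "add"-specialized fc2 of Mlp.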
x = self.attn(self.norm1(x), x)
x = self.mlp(self.norm2(x), x)
return x
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size // patch_size, img_size // patch_size)
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.embed_dim = embed_dim
conv_op = (
nn.Conv2dBiasFewChannels
if detect_target().name() == "cuda"
else nn.Conv2dBias
)
self.proj = conv_op(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.proj_norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, H, W, C = get_shape(x)
x = self.proj(x)
if self.flatten:
x = ops.reshape()(x, [B, -1, self.embed_dim])
x = self.proj_norm(x)
return x
class VisionTransformer(nn.Module):
"""Vision Transformer
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
"""
def __init__(
self,
img_size=224,
batch_size=1,
patch_size=16,
in_chans=3,
num_classes=1000,
global_pool="token",
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
init_values=None,
class_token=True,
no_embed_class=False,
fc_norm=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
weight_init="",
embed_layer=PatchEmbed,
norm_layer=nn.LayerNorm,
act_layer=None,
block_fn=Block,
dtype="float16",
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'token')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token
            fc_norm (Optional[bool]): apply a pre-fc norm after pooling; if None, defaults to global_pool == 'avg' (default: None)
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ("", "avg", "token")
assert class_token or global_pool != "token"
use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = (
self.embed_dim
) = embed_dim # num_features for consistency with other models
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
num_patches = self.patch_embed.num_patches
self.cls_token = (
nn.Parameter(shape=[1, 1, embed_dim], dtype=dtype) if class_token else None
)
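        # cls_token_mask is set to zeros at run time (see benchmark_ait.py); adding
        # it to the expanded class token yields a concrete per-batch copy before
        # the concatenation in _pos_embed.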
self.cls_token_mask = (
nn.Parameter(shape=[batch_size, 1, embed_dim], dtype=dtype)
if class_token
else None
)
embed_len = (
num_patches if no_embed_class else num_patches + self.num_prefix_tokens
)
self.pos_embed = nn.Parameter(shape=[1, embed_len, embed_dim], dtype=dtype)
self.pos_drop = nn.Dropout(p=drop_rate)
seq_len = (img_size // patch_size) ** 2 + (1 if class_token else 0)
self.pool_size = img_size // patch_size
self.blocks = nn.Sequential(
*[
block_fn(
dim=embed_dim,
batch_size=batch_size,
seq_len=seq_len,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=0,
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
if global_pool == "avg":
self.pool = nn.AvgPool2d(kernel_size=self.pool_size, stride=1, padding=0)
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
self.head = (
nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
)
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed.tensor()
if self.cls_token is not None:
cls_token_expand = ops.expand()(
self.cls_token.tensor(), [get_shape(x)[0], -1, -1]
)
cls_token_expand = cls_token_expand + self.cls_token_mask.tensor()
x = ops.concatenate()([cls_token_expand, x], dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
cls_token_expand = ops.expand()(
self.cls_token.tensor(), [get_shape(x)[0], -1, -1]
)
cls_token_expand = cls_token_expand + self.cls_token_mask.tensor()
x = ops.concatenate()([cls_token_expand, x], dim=1)
x = x + self.pos_embed.tensor()
return self.pos_drop(x)
def forward_features(self, x):
x = self.patch_embed(x)
x = self._pos_embed(x)
x = self.blocks(x)
x = self.norm(x)
return x
def _global_pool(self, x):
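        # Global average pooling: fold the [B, S, D] tokens back into the
        # [B, pool_size, pool_size, D] patch grid and average over the full grid.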
batch, seq, d = get_shape(x)
x = ops.reshape()(x, [batch, self.pool_size, self.pool_size, d])
y = self.pool(x)
return ops.reshape()(y, [batch, d])
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
if self.global_pool == "avg":
x = self._global_pool(x)
else:
batch, seq, d = get_shape(x)
x = ops.dynamic_slice()(
x, start_indices=[0, 0, 0], end_indices=[batch, 1, d]
)
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
from .version import __version__
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config, resolve_model_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .dataset_info import DatasetInfo, CustomDatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader
from .mixup import Mixup, FastCollateMixup
from .readers import create_reader
from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform