Commit 037a1c83 authored by chenzk

v1.0
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma model config."""
import dataclasses
import immutabledict
import torch
from typing import Optional
# Keep a mapping from dtype strings to the supported torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({
'float16': torch.float16,
'float': torch.float32,
'float32': torch.float32,
'bfloat16': torch.bfloat16,
})
@dataclasses.dataclass
class GemmaConfig:
# The number of tokens in the vocabulary.
vocab_size: int = 256000
# The maximum sequence length that this model might ever be used with.
max_position_embeddings: int = 8192
# The number of blocks in the model.
num_hidden_layers: int = 28
# The number of attention heads used in the attention layers of the model.
num_attention_heads: int = 16
# The number of key-value heads for implementing attention.
num_key_value_heads: int = 16
# The hidden size of the model.
hidden_size: int = 3072
# The dimension of the MLP representations.
intermediate_size: int = 24576
# The number of head dimensions.
head_dim: int = 256
# The epsilon used by the rms normalization layers.
rms_norm_eps: float = 1e-6
# The dtype of the weights.
dtype: str = 'bfloat16'
# Whether a quantized version of the model is used.
quant: bool = False
# The path to the model tokenizer.
tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
def get_dtype(self) -> Optional[torch.dtype]:
"""Gets the torch dtype from the config dtype string."""
return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
def get_config_for_7b() -> GemmaConfig:
return GemmaConfig()
def get_config_for_2b() -> GemmaConfig:
return GemmaConfig(
num_hidden_layers=18,
num_attention_heads=8,
num_key_value_heads=1,
hidden_size=2048,
intermediate_size=16384
)
def get_model_config(variant: str) -> GemmaConfig:
if variant == '7b':
return get_config_for_7b()
elif variant == '2b':
return get_config_for_2b()
    raise ValueError(f'Invalid variant {variant}. Supported variants are '
                     '"2b" and "7b".')
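
if __name__ == '__main__':
    # Quick usage sketch (illustrative only; not part of the library API).
    cfg = get_model_config('2b')
    print(cfg.num_hidden_layers, cfg.hidden_size)  # 18 2048
    cfg.dtype = 'float16'
    print(cfg.get_dtype())  # torch.float16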
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""
import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union
from gemma import config as gemma_config
from gemma.xla_model_parallel import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
reduce_from_model_parallel_region,
scatter_to_model_parallel_region,
)
class Sampler(nn.Module):
def __init__(self, vocab_size: int, world_size: int, rank: int) -> None:
super().__init__()
self.vocab_size = vocab_size
self.world_size = world_size
self.rank = rank
@torch.no_grad()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Select the last element for each sequence.
# (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
hidden_states = hidden_states.index_select(
1, output_positions).squeeze(dim=1)
hidden_states_parallel = scatter_to_model_parallel_region(
hidden_states,
groups=None,
world_size=self.world_size,
rank=self.rank)
hidden_states_parallel = torch.matmul(hidden_states_parallel,
embedding.t())
logits = reduce_from_model_parallel_region(
hidden_states_parallel,
groups=None,
world_size=self.world_size,
rank=self.rank,
)
if embedding_bias is not None:
logits += embedding_bias
if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1)
# Apply temperature scaling.
logits.div_(temperatures.unsqueeze(dim=1))
# Calculate probabilities with softmax.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# Apply top-p, top-k.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
probs_sort = torch.where(top_ps_mask, 0, probs_sort)
top_ks_mask = torch.arange(probs_idx.shape[-1],
device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)
# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids
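
# Sampler notes: logits are formed by multiplying the selected hidden states
# with the (model-parallel sharded) embedding matrix and reducing across
# shards. If `temperatures` is None the argmax token is returned; otherwise
# the logits are divided by the per-sequence temperature, converted to
# probabilities, filtered with the top-p (nucleus) and top-k masks on the
# sorted probabilities, renormalized, and one token per sequence is drawn
# with torch.multinomial.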
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
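
# Shape notes for the two helpers above: precompute_freqs_cis(dim, end) returns
# a complex64 tensor of shape (end, dim // 2); apply_rotary_emb expects x of
# shape (batch_size, seq_len, n_heads, head_dim) plus the freqs_cis rows
# selected for those positions, and returns a tensor of the same shape as x.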
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
x = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
output = x * (1 + self.weight)
else:
output = x * self.weight
return output
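
# For reference, RMSNorm above computes (in float32, then cast back to the
# input dtype):
#   y = x * rsqrt(mean(x ** 2, dim=-1, keepdim=True) + eps) * (1 + weight)
# or `* weight` when add_unit_offset=False; with the unit offset, a zero
# weight acts as an identity scale.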
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
world_size: int,
rank: int,
quant: bool,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
def init_method(x):
return x
self.gate_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
def forward(self, x):
gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
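
# GemmaMLP above is a GeGLU-style feed-forward:
#   out = down_proj(gelu(gate_proj(x), approximate='tanh') * up_proj(x))
# gate_proj/up_proj are column-parallel (sharded over intermediate_size) and
# down_proj is row-parallel, so the elementwise product stays local to each
# model-parallel shard.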
class GemmaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
world_size: int,
rank: int,
quant: bool,
):
super().__init__()
self.rank = rank
def init_method(x):
return x
self.total_num_heads = num_heads
assert self.total_num_heads % world_size == 0
self.num_heads = self.total_num_heads // world_size # head per shard
if num_kv_heads < world_size:
assert world_size % num_kv_heads == 0
self.total_num_kv_heads = world_size
else:
assert num_kv_heads % world_size == 0
self.total_num_kv_heads = num_kv_heads
self.num_kv_heads = self.total_num_kv_heads // world_size # kv head per shard
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# [batch_size, input_len, hidden_dim]
output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output
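
# Attention notes: rotary position embeddings are applied to queries and keys,
# the new key/value rows are written into the per-layer cache at
# kv_write_indices, and when num_kv_heads < num_heads the cached key/value
# tensors are repeat-interleaved so each group of query heads shares one KV
# head (grouped-query attention). Scores are scaled by head_dim ** -0.5 and
# the additive mask is applied before the softmax.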
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int,
):
super().__init__()
self.rank = rank
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class GemmaModel(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int
):
super().__init__()
self.config = config
self.rank = rank
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(GemmaDecoderLayer(config, world_size, rank))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
) -> torch.Tensor:
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_caches[i],
mask=mask,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int,
device: torch.device,
):
super().__init__()
self.config = config
self.world_size = world_size
self.rank = rank
self.device = device
assert config.num_attention_heads % world_size == 0
assert config.hidden_size % config.num_attention_heads == 0
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size
def init_method(x):
return x
self.embedder = ParallelEmbedding(
vocab_size,
config.hidden_size,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.model = GemmaModel(config, world_size, rank)
self.sampler = Sampler(vocab_size, world_size, rank)
rope_theta = getattr(config, 'rope_theta', 10000)
# [head_dim * 2, ] -> complex -> two dim (real, imaginary) implicitly
freqs_cis = precompute_freqs_cis(head_dim,
max_seq_len * 2,
theta=rope_theta)
self.register_buffer('freqs_cis', freqs_cis)
@torch.no_grad()
def forward(
self,
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
**kwargs,
) -> torch.Tensor:
freqs_cis = self.freqs_cis.index_select(0, input_positions)
kv_write_indices = input_positions
hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
hidden_states = hidden_states * (self.config.hidden_size**0.5)
# hidden_states should be [batch_size, input_len, hidden_size]
hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
)
embedder_weight = self.embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
next_tokens = self.sampler(
embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens
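
    # Note: the output projection is tied to the input embedding (the sampler
    # multiplies hidden states by embedder.weight, rescaled by weight_scaler
    # for quantized checkpoints), and the embedded inputs are multiplied by
    # sqrt(hidden_size) before entering the decoder stack.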
def load_weights(self, model_path: str):
checkpoint = torch.load(model_path, weights_only=True)
model_state_dict = checkpoint['model_state_dict']
num_attn_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.config.head_dim
hidden_size = self.config.hidden_size
def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
axis_len = tensor.shape[axis]
split_len = axis_len // self.world_size
split_start = split_len * self.rank
split_end = split_start + split_len
tensor = torch.moveaxis(tensor, axis, 0)
tensor = tensor[split_start:split_end, ...]
tensor = torch.moveaxis(tensor, 0, axis)
return tensor
for k, v in model_state_dict.items():
if k == 'freqs_cis':
continue
if (k == 'model.norm.weight' or re.fullmatch(
r'model.layers.\d+.input_layernorm.weight', k)
or re.fullmatch(
r'model.layers.\d+.post_attention_layernorm.weight',
k) or k.endswith('weight_scaler')):
pass
elif (k == 'embedder.weight' or re.fullmatch(
r'model.layers.\d+.mlp.down_proj.weight', k)):
v = split(v, 1)
elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
v = split(v, 0)
elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
k):
if num_kv_heads <= self.world_size:
num_replicas = self.world_size // num_kv_heads
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
hidden_size)
query = v[:num_attn_heads, ...]
key = v[num_attn_heads:num_attn_heads + num_kv_heads,
...].repeat(num_replicas, 1, 1)
value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
v = torch.cat(
(split(query, 0), split(key, 0), split(value, 0)),
dim=0)
else:
v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
v = split(v, 1)
v = v.reshape(-1, hidden_size)
elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
v = v.reshape(hidden_size, num_attn_heads, head_dim)
v = split(v, 1)
v = v.reshape(hidden_size, -1)
else:
raise ValueError(f'Unrecognized key: {k}')
self.state_dict()[k].copy_(v)
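
# Sharding summary for load_weights above:
#   - embedder.weight and mlp.down_proj.weight are split along dim 1,
#   - mlp.gate_proj.weight and mlp.up_proj.weight are split along dim 0,
#   - qkv_proj.weight is reshaped per head before splitting so each rank keeps
#     whole heads, replicating KV heads when there are fewer of them than ranks,
#   - o_proj.weight is split along its head axis and flattened back,
#   - norm weights and quantization scalers are replicated on every rank.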
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional
from sentencepiece import SentencePieceProcessor
class Tokenizer:
def __init__(self, model_path: Optional[str]):
# Reload tokenizer.
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs.
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
"""Converts a string into a list of tokens."""
assert isinstance(s, str)
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
"""Converts a list of tokens into a string."""
return self.sp_model.decode(t)
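
# Usage sketch (illustrative; assumes the SentencePiece model referenced by
# GemmaConfig.tokenizer exists on disk):
#
#   tok = Tokenizer('tokenizer/tokenizer.model')
#   ids = tok.encode('The meaning of life is')  # BOS id is prepended by default
#   text = tok.decode(ids)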
#!/bin/bash
python scripts/run.py --ckpt="gemma-2b-pytorch/gemma-2b-it.ckpt" --variant=2b --prompt="The meaning of life is" --device=cuda
# Model code
modelCode=560
# Model name
modelName=gemma_pytorch
# Model description
modelDescription=Gemma, the new-generation open-source 2B small model released by Google, billed as "the world's most capable, lightweight" model and the opening shot in the small-model race.
# Application scenarios
appScenario=inference,manufacturing,broadcast media,finance,energy,healthcare,smart home,education
# Framework type
frameType=pytorch
fairscale == 0.4.13
numpy == 1.24.4
immutabledict == 4.1.0
sentencepiece == 0.1.99
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import random
import numpy as np
import torch
from gemma import config
from gemma import model as gemma_model
import time
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(args):
# Construct the model config.
model_config = config.get_model_config(args.variant)
model_config.dtype = "float32" if args.device == "cpu" else "float16"
model_config.quant = args.quant
# Seed random.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Create the model and load the weights.
device = torch.device(args.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(args.ckpt)
model = model.to(device).eval()
print("Model loading done")
# Generate the response.
# start_time = time.time()
result = model.generate(args.prompt, device, output_len=args.output_len)
# print("infer time:", time.time() - start_time, "s")
# Print the prompts and results.
print('======================================')
print(f'PROMPT: {args.prompt}')
print(f'RESULT: {result}')
print('======================================')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--device",
type=str,
default="cpu",
choices=["cpu", "cuda"])
parser.add_argument("--output_len", type=int, default=100)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)
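
# Example invocation (same command as the shell script above):
#   python scripts/run.py --ckpt="gemma-2b-pytorch/gemma-2b-it.ckpt" \
#       --variant=2b --prompt="The meaning of life is" --device=cuda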
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import random
import socket
import sys
from typing import List
import numpy as np
import torch
import torch.multiprocessing
from gemma.config import GemmaConfig, get_model_config
from gemma.model_xla import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import gemma.xla_model_parallel as xla_model_parallel
USE_CUDA = os.environ.get('USE_CUDA', False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
# Choose an available port.
with contextlib.closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
MASTER_PORT = str(s.getsockname()[1])
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def generate(i: int, model_config: GemmaConfig, ckpt_path: str,
prompts: List[str], output_lens: List[int],
temperatures: List[float], top_ps: List[float],
top_ks: List[int], seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device("cuda", local_rank)
torch.cuda.set_device(local_rank)
else:
device = xm.xla_device()
xm.set_rng_state(seed, device)
rank = xla_model_parallel.get_model_parallel_rank()
world_size = xla_model_parallel.get_model_parallel_world_size()
if rank > 0:
sys.stdout = open(os.devnull, 'w')
# build, load and compile model.
with _set_default_tensor_type(model_config.get_dtype()):
model = GemmaForCausalLM(model_config, world_size, rank, device)
model.load_weights(ckpt_path)
model = model.to(device).eval()
# create tokenizer.
tokenizer = Tokenizer(model_config.tokenizer)
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
batch_size = len(prompts)
assert batch_size == len(temperatures)
assert batch_size == len(top_ps)
assert batch_size == len(top_ks)
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
assert max_seq_len <= model_config.max_position_embeddings
if model_config.num_key_value_heads < world_size:
assert world_size % model_config.num_key_value_heads == 0
n_local_heads = 1
else:
assert model_config.num_key_value_heads % world_size == 0
n_local_heads = model_config.num_key_value_heads // world_size
# build KV caches
kv_caches = []
for _ in range(model_config.num_hidden_layers):
k_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
v_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
tokenizer.pad_id,
dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
input_token_ids_tensor = input_token_ids_tensor.to(device)
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
temperatures_tensor = torch.FloatTensor(temperatures).to(device)
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
top_ks_tensor = torch.LongTensor(top_ks).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
if not USE_CUDA:
xm.mark_step()
    # Prefill up to min_prompt_len tokens, then treat the remaining prefill
    # positions as decode steps and ignore their sampled outputs.
for i in range(max_seq_len - min_prompt_len):
next_token_ids = model(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
output_index = output_index + 1
if not USE_CUDA:
xm.mark_step()
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
output_lens[i]]
if tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(tokenizer.decode(trimmed_output))
for prompt, result in zip(prompts, results):
print('======================================')
print(f'PROMPT: {prompt}')
print(f'RESULT: {result}')
print('======================================')
def main(args):
model_config = get_model_config(args.variant)
model_config.quant = args.quant
prompts = [args.prompt]
n = len(prompts)
output_lengths = [args.output_len] * n
temperatures = [0.95] * n
top_ps = [1.0] * n
top_ks = [100] * n
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
torch.multiprocessing.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
else:
xmp.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--output_len", type=int, default=4)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)
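
# Example invocation (illustrative; the script path is an assumption):
#   python scripts/run_xla.py --ckpt=<checkpoint.ckpt> --variant=2b \
#       --prompt="The meaning of life is"
# Set USE_CUDA=1 in the environment to use the NCCL/GPU branch instead of XLA.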
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import List
import setuptools
ROOT_DIR = os.path.dirname(__file__)
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def read_readme() -> str:
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
return requirements
setuptools.setup(
name="gemma",
version="0.1",
author="Gemma contributors",
license="Apache 2.0",
description=("Gemma model implementation"),
long_description=read_readme(),
long_description_content_type="text/markdown",
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(exclude=("benchmarks", "docs",
"examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
)
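
# Typical install from the repository root (standard pip/setuptools workflow):
#   pip install -e .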