"discover/gpu_linux.go" did not exist on "b732beba6a919b852539bb344b05e25c6a7c3c90"
Commit 037a1c83 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #811 failed with stages
in 0 seconds
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma model config."""
import dataclasses
import immutabledict
import torch
from typing import Optional
# Keep a mapping from dtype strings to the supported torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = immutabledict.immutabledict({
'float16': torch.float16,
'float': torch.float32,
'float32': torch.float32,
'bfloat16': torch.bfloat16,
})
@dataclasses.dataclass
class GemmaConfig:
# The number of tokens in the vocabulary.
vocab_size: int = 256000
# The maximum sequence length that this model might ever be used with.
max_position_embeddings: int = 8192
# The number of blocks in the model.
num_hidden_layers: int = 28
# The number of attention heads used in the attention layers of the model.
num_attention_heads: int = 16
# The number of key-value heads for implementing attention.
num_key_value_heads: int = 16
# The hidden size of the model.
hidden_size: int = 3072
# The dimension of the MLP representations.
intermediate_size: int = 24576
# The number of head dimensions.
head_dim: int = 256
# The epsilon used by the rms normalization layers.
rms_norm_eps: float = 1e-6
# The dtype of the weights.
dtype: str = 'bfloat16'
# Whether a quantized version of the model is used.
quant: bool = False
# The path to the model tokenizer.
tokenizer: Optional[str] = 'tokenizer/tokenizer.model'
def get_dtype(self) -> Optional[torch.dtype]:
"""Gets the torch dtype from the config dtype string."""
return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)
def get_config_for_7b() -> GemmaConfig:
return GemmaConfig()
def get_config_for_2b() -> GemmaConfig:
return GemmaConfig(
num_hidden_layers=18,
num_attention_heads=8,
num_key_value_heads=1,
hidden_size=2048,
intermediate_size=16384
)
def get_model_config(variant: str) -> GemmaConfig:
if variant == '7b':
return get_config_for_7b()
elif variant == '2b':
return get_config_for_2b()
raise ValueError(f'Invalid variant {variant}. Supported variants are "2b" and "7b".')
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""
import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union
from gemma import config as gemma_config
from gemma import tokenizer
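# Samples the next token from the final hidden state: greedy argmax when
# temperatures is None, otherwise temperature scaling followed by top-p
# (nucleus) and top-k filtering over the softmax distribution.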
class Sampler(nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = vocab_size
@torch.no_grad()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Select the last element for each sequence.
# (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
hidden_states = hidden_states.index_select(
1, output_positions).squeeze(dim=1)
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1)
# Apply temperature scaling.
logits.div_(temperatures.unsqueeze(dim=1))
# Calculate probabilities with softmax.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# Apply top-p, top-k.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
probs_sort = torch.where(top_ps_mask, 0, probs_sort)
top_ks_mask = torch.arange(probs_idx.shape[-1],
device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)
# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids
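# Rotary positional embedding (RoPE) helpers: the per-position rotation
# factors are precomputed once as complex numbers and indexed by position
# at inference time.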
def precompute_freqs_cis(dim: int,
end: int,
theta: float = 10000.0) -> torch.Tensor:
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
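# Minimal Linear/Embedding wrappers with optional int8 weights: when
# quant=True the weight is stored as int8 plus a per-row float scaler and
# dequantized on the fly in forward() as weight * weight_scaler.unsqueeze(-1).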
class Linear(nn.Module):
def __init__(self, in_features: int, out_features: int, quant: bool):
super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((out_features, in_features), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
else:
self.weight = nn.Parameter(
torch.empty((out_features, in_features)),
requires_grad=False,
)
self.quant = quant
def forward(self, x):
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.linear(x, weight)
return output
class Embedding(nn.Module):
def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
else:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim)),
requires_grad=False,
)
self.quant = quant
def forward(self, x):
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.embedding(x, weight)
return output
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
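# With add_unit_offset, forward() scales by (1 + weight) rather than weight,
# so the zero-initialized parameter below acts as an identity gain until a
# checkpoint is loaded.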
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
x = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
output = x * (1 + self.weight)
else:
output = x * self.weight
return output
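# Gated MLP (GeGLU-style): down_proj(gelu(gate_proj(x)) * up_proj(x)).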
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant: bool,
):
super().__init__()
self.gate_proj = Linear(hidden_size, intermediate_size, quant)
self.up_proj = Linear(hidden_size, intermediate_size, quant)
self.down_proj = Linear(intermediate_size, hidden_size, quant)
def forward(self, x):
gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
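# Multi-head attention with an in-place KV cache and optional grouped-query
# attention: when num_kv_heads < num_heads, each KV head is shared by
# num_heads // num_kv_heads query heads via repeat_interleave below.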
class GemmaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
quant: bool,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
quant=quant)
self.o_proj = Linear(
self.num_heads * self.head_dim,
self.hidden_size,
quant=quant)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# [batch_size, input_len, hidden_dim]
output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
quant=config.quant,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class GemmaModel(nn.Module):
def __init__(self, config: gemma_config.GemmaConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(GemmaDecoderLayer(config))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
) -> torch.Tensor:
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_caches[i],
mask=mask,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.config = config
assert config.hidden_size % config.num_attention_heads == 0
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size
self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
self.model = GemmaModel(config)
self.sampler = Sampler(vocab_size)
# Pre-compute rotary embedding table.
rope_theta = getattr(config, 'rope_theta', 10000)
freqs_cis = precompute_freqs_cis(head_dim,
max_seq_len * 2,
theta=rope_theta)
self.register_buffer('freqs_cis', freqs_cis)
@torch.no_grad()
def forward(
self,
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
**kwargs,
) -> torch.Tensor:
freqs_cis = self.freqs_cis.index_select(0, input_positions)
kv_write_indices = input_positions
# [batch_size, input_len, hidden_size]
hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
hidden_states = hidden_states * (self.config.hidden_size**0.5)
hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
)
embedder_weight = self.embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
next_tokens = self.sampler(
embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens
def generate(
self,
prompts: Union[str, Sequence[str]],
device: Any,
output_len: int = 100,
temperature: float = 0.95,
top_p: float = 1.0,
top_k: int = 100,
) -> Union[str, Sequence[str]]:
"""Generates responses for given prompts using Gemma model."""
# If a single prompt is provided, treat it as a batch of 1.
is_str_prompt = isinstance(prompts, str)
if is_str_prompt:
prompts = [prompts]
batch_size = len(prompts)
prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
max_prompt_len = max(len(p) for p in prompt_tokens)
max_seq_len = max_prompt_len + output_len
assert max_seq_len <= self.config.max_position_embeddings
# build KV caches
kv_caches = []
for _ in range(self.config.num_hidden_layers):
size = (batch_size, max_seq_len, self.config.num_key_value_heads,
self.config.head_dim)
dtype = self.config.get_dtype()
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
self.tokenizer.pad_id, dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
self.tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
input_token_ids_tensor = input_token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
device)
temperatures_tensor = torch.FloatTensor([temperature] * batch_size).to(
device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
device)
# Prefill up to min_prompt_len tokens, then treat the remaining prompt
# tokens as decode steps whose sampled outputs are ignored.
for i in range(max_seq_len - min_prompt_len):
next_token_ids = self(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2,
input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
device)
output_index = output_index + 1
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
+ output_len]
if self.tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(self.tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(self.tokenizer.decode(trimmed_output))
# If a string was provided as input, return a string as output.
return results[0] if is_str_prompt else results
def load_weights(self, model_path: str):
self.load_state_dict(
torch.load(
model_path, mmap=True, weights_only=True,
)['model_state_dict'],
strict=False,
)
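# Usage sketch (illustrative only, not part of the original file; paths are
# placeholders and config.tokenizer must point to a real SentencePiece model):
#
#   cfg = gemma_config.get_model_config('2b')
#   model = GemmaForCausalLM(cfg)
#   model.load_weights('/path/to/gemma-2b-it.ckpt')
#   model = model.to('cpu').eval()
#   print(model.generate('The meaning of life is', device='cpu', output_len=32))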
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""
import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union
from gemma import config as gemma_config
from gemma.xla_model_parallel import (
ColumnParallelLinear,
ParallelEmbedding,
RowParallelLinear,
reduce_from_model_parallel_region,
scatter_to_model_parallel_region,
)
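# Model-parallel Sampler: the final hidden state is scattered across ranks,
# multiplied against the local shard of the embedding matrix, and the
# partial logits are all-reduced before sampling.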
class Sampler(nn.Module):
def __init__(self, vocab_size: int, world_size: int, rank: int) -> None:
super().__init__()
self.vocab_size = vocab_size
self.world_size = world_size
self.rank = rank
@torch.no_grad()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Select the last element for each sequence.
# (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
hidden_states = hidden_states.index_select(
1, output_positions).squeeze(dim=1)
hidden_states_parallel = scatter_to_model_parallel_region(
hidden_states,
groups=None,
world_size=self.world_size,
rank=self.rank)
hidden_states_parallel = torch.matmul(hidden_states_parallel,
embedding.t())
logits = reduce_from_model_parallel_region(
hidden_states_parallel,
groups=None,
world_size=self.world_size,
rank=self.rank,
)
if embedding_bias is not None:
logits += embedding_bias
if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1)
# Apply temperature scaling.
logits.div_(temperatures.unsqueeze(dim=1))
# Calculate probabilities with softmax.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
# Apply top-p, top-k.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
probs_sort = torch.where(top_ps_mask, 0, probs_sort)
top_ks_mask = torch.arange(probs_idx.shape[-1],
device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)
# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))
next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
class RMSNorm(torch.nn.Module):
def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
x = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
output = x * (1 + self.weight)
else:
output = x * self.weight
return output
class GemmaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
world_size: int,
rank: int,
quant: bool,
):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
def init_method(x):
return x
self.gate_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
def forward(self, x):
gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
class GemmaAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
world_size: int,
rank: int,
quant: bool,
):
super().__init__()
self.rank = rank
def init_method(x):
return x
self.total_num_heads = num_heads
assert self.total_num_heads % world_size == 0
self.num_heads = self.total_num_heads // world_size # head per shard
if num_kv_heads < world_size:
assert world_size % num_kv_heads == 0
self.total_num_kv_heads = world_size
else:
assert num_kv_heads % world_size == 0
self.total_num_kv_heads = num_kv_heads
self.num_kv_heads = self.total_num_kv_heads // world_size # kv head per shard
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(
self.hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
input_is_parallel=True,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=quant,
)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3
batch_size, input_len, _ = hidden_states_shape
qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)
# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)
# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)
key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)
# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)
# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)
# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)
# [batch_size, input_len, hidden_dim]
output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output
class GemmaDecoderLayer(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int,
):
super().__init__()
self.rank = rank
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = residual + hidden_states
# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class GemmaModel(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int
):
super().__init__()
self.config = config
self.rank = rank
self.vocab_size = config.vocab_size
self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(GemmaDecoderLayer(config, world_size, rank))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
) -> torch.Tensor:
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_caches[i],
mask=mask,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class GemmaForCausalLM(nn.Module):
def __init__(
self,
config: gemma_config.GemmaConfig,
world_size: int,
rank: int,
device: torch.device,
):
super().__init__()
self.config = config
self.world_size = world_size
self.rank = rank
self.device = device
assert config.num_attention_heads % world_size == 0
assert config.hidden_size % config.num_attention_heads == 0
max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size
def init_method(x):
return x
self.embedder = ParallelEmbedding(
vocab_size,
config.hidden_size,
init_method=init_method,
world_size=world_size,
rank=rank,
quant=config.quant,
)
self.model = GemmaModel(config, world_size, rank)
self.sampler = Sampler(vocab_size, world_size, rank)
rope_theta = getattr(config, 'rope_theta', 10000)
# Precompute the rotary embedding table (stored as complex values, i.e. real
# and imaginary parts).
freqs_cis = precompute_freqs_cis(head_dim,
max_seq_len * 2,
theta=rope_theta)
self.register_buffer('freqs_cis', freqs_cis)
@torch.no_grad()
def forward(
self,
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: torch.Tensor,
top_ps: torch.Tensor,
top_ks: torch.Tensor,
**kwargs,
) -> torch.Tensor:
freqs_cis = self.freqs_cis.index_select(0, input_positions)
kv_write_indices = input_positions
hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
hidden_states = hidden_states * (self.config.hidden_size**0.5)
# hidden_states should be [batch_size, input_len, hidden_size]
hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
)
embedder_weight = self.embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
next_tokens = self.sampler(
embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens
def load_weights(self, model_path: str):
checkpoint = torch.load(model_path, weights_only=True)
model_state_dict = checkpoint['model_state_dict']
num_attn_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = self.config.head_dim
hidden_size = self.config.hidden_size
def split(tensor: torch.Tensor, axis: int) -> torch.Tensor:
axis_len = tensor.shape[axis]
split_len = axis_len // self.world_size
split_start = split_len * self.rank
split_end = split_start + split_len
tensor = torch.moveaxis(tensor, axis, 0)
tensor = tensor[split_start:split_end, ...]
tensor = torch.moveaxis(tensor, 0, axis)
return tensor
for k, v in model_state_dict.items():
if k == 'freqs_cis':
continue
if (k == 'model.norm.weight' or re.fullmatch(
r'model.layers.\d+.input_layernorm.weight', k)
or re.fullmatch(
r'model.layers.\d+.post_attention_layernorm.weight',
k) or k.endswith('weight_scaler')):
pass
elif (k == 'embedder.weight' or re.fullmatch(
r'model.layers.\d+.mlp.down_proj.weight', k)):
v = split(v, 1)
elif (re.fullmatch(r'model.layers.\d+.mlp.gate_proj.weight', k)
or re.fullmatch(r'model.layers.\d+.mlp.up_proj.weight', k)):
v = split(v, 0)
elif re.fullmatch(r'model.layers.\d+.self_attn.qkv_proj.weight',
k):
if num_kv_heads <= self.world_size:
num_replicas = self.world_size // num_kv_heads
v = v.reshape(num_attn_heads + num_kv_heads * 2, head_dim,
hidden_size)
query = v[:num_attn_heads, ...]
key = v[num_attn_heads:num_attn_heads + num_kv_heads,
...].repeat(num_replicas, 1, 1)
value = v[-num_kv_heads:, ...].repeat(num_replicas, 1, 1)
v = torch.cat(
(split(query, 0), split(key, 0), split(value, 0)),
dim=0)
else:
v = v.reshape(3, num_attn_heads, head_dim, hidden_size)
v = split(v, 1)
v = v.reshape(-1, hidden_size)
elif re.fullmatch(r'model.layers.\d+.self_attn.o_proj.weight', k):
v = v.reshape(hidden_size, num_attn_heads, head_dim)
v = split(v, 1)
v = v.reshape(hidden_size, -1)
else:
raise ValueError(f'Unrecognized key: {k}')
self.state_dict()[k].copy_(v)
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional
from sentencepiece import SentencePieceProcessor
class Tokenizer:
def __init__(self, model_path: Optional[str]):
# Reload tokenizer.
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs.
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
"""Converts a string into a list of tokens."""
assert isinstance(s, str)
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
"""Converts a list of tokens into a string."""
return self.sp_model.decode(t)
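# Usage sketch (illustrative only, not part of the original file; the model
# path is a placeholder):
#
#   tok = Tokenizer('tokenizer/tokenizer.model')
#   ids = tok.encode('Hello Gemma')   # prepends bos_id by default
#   text = tok.decode(ids)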
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
import os
from typing import Callable, List, Optional
from fairscale.nn.model_parallel.utils import divide_and_check_no_remainder, split_tensor_along_last_dim
import torch
import torch.ao.quantization.fx._decomposed
import torch.distributed as dist
import torch.distributed._functional_collectives as fc
import torch.distributed.distributed_c10d as c10d
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
EPS = torch.finfo(torch.float32).eps
USE_CUDA = os.environ.get('USE_CUDA', False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
TAG = None
RANKSET = None
GROUP_SIZE = None
def set_g_group():
global TAG
global RANKSET
global GROUP_SIZE
assert USE_CUDA, "This hack is only for PyTorch non-XLA CUDA paths, i.e., eager and inductor."
TAG, RANKSET, GROUP_SIZE = fc._expand_group(c10d._get_default_group())
@dataclass
class TensorQConfig:
dtype: torch.dtype = torch.int8
axis: int = -1
quant_min: int = -128
quant_max: int = 127
symmetric_quant: bool = True
def _find_per_channel_min_max(x: torch.Tensor, axis: int):
x_dim = x.size()
new_axis_list = list(range(len(x_dim)))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = x.permute(new_axis_list)
y = torch.flatten(y, start_dim=1)
return torch.aminmax(y, dim=1)
def _find_qparams(x: torch.Tensor, qconfig: TensorQConfig):
# Only support per-channel symmetric quant to int8 now
axis = qconfig.axis
dtype = qconfig.dtype
symmetric_quant = qconfig.symmetric_quant
quant_min = qconfig.quant_min
quant_max = qconfig.quant_max
assert axis >= 0 and axis < len(x.shape)
assert dtype == torch.int8
min_val, max_val = _find_per_channel_min_max(x, axis)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
scale = torch.ones(min_val_neg.size(), dtype=torch.float32)
if symmetric_quant:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
eps = torch.zeros_like(scale).fill_(EPS)
scale = torch.max(scale, eps)
return scale, None
else:
assert symmetric_quant
def _quantize_to_dtype(
x: torch.Tensor,
qconfig: TensorQConfig,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
):
if zero_point is None:
zero_point = torch.zeros_like(scale)
return torch.ops.quantized_decomposed.quantize_per_channel(
x,
scale,
zero_point,
qconfig.axis,
qconfig.quant_min,
qconfig.quant_max,
qconfig.dtype,
)
def quantize_tensor(x: torch.Tensor, qconfig: TensorQConfig):
scale, zp = _find_qparams(x, qconfig)
x_int = _quantize_to_dtype(x, qconfig, scale, zp)
return x_int, scale, zp
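# Usage sketch (illustrative only, not part of the original file): per-channel
# symmetric int8 quantization along dim 0; since quantization is symmetric the
# zero point is None and dequantization is approximately
# w_int8.float() * scale.unsqueeze(-1).
#
#   w = torch.randn(4, 8)
#   w_int8, scale, _ = quantize_tensor(w, TensorQConfig(axis=0))
#   w_approx = w_int8.float() * scale.unsqueeze(-1)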
def get_model_parallel_rank():
if USE_CUDA:
return dist.get_rank()
return xm.get_ordinal()
def get_model_parallel_world_size():
if USE_CUDA:
return dist.get_world_size()
return xm.xrt_world_size()
def get_model_parallel_group():
return None
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return input_
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_reduce(grad_output, groups, world_size, rank)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-redcue the input from the model parallel region."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
return my_reduce(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return my_split(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_gather(grad_output, groups, world_size, rank)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def forward(ctx, input_, groups, world_size, rank): # type: ignore
ctx.groups, ctx.world_size, ctx.rank = groups, world_size, rank
return my_gather(input_, groups, world_size, rank)
@staticmethod
def backward(ctx, grad_output): # type: ignore
groups, world_size, rank = ctx.groups, ctx.world_size, ctx.rank
return my_split(grad_output, groups, world_size, rank)
# -----------------
# Helper functions.
# -----------------
def copy_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
rank) -> torch.Tensor:
return _CopyToModelParallelRegion.apply(input_, groups, world_size, rank)
def reduce_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
rank) -> torch.Tensor:
return _ReduceFromModelParallelRegion.apply(input_, groups, world_size,
rank)
def scatter_to_model_parallel_region(input_: torch.Tensor, groups, world_size,
rank) -> torch.Tensor:
return _ScatterToModelParallelRegion.apply(input_, groups, world_size,
rank)
def gather_from_model_parallel_region(input_: torch.Tensor, groups, world_size,
rank) -> torch.Tensor:
return _GatherFromModelParallelRegion.apply(input_, groups, world_size,
rank)
# Below copied from fairscale/nn/model_parallel/layers.py
def my_reduce(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""All-reduce the the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# All-reduce.
if USE_CUDA:
input_ = torch.ops.c10d_functional.all_reduce(input_, "sum", TAG,
RANKSET, GROUP_SIZE)
else:
input_ = xm.all_reduce(xm.REDUCE_SUM, input_, groups=groups)
return input_
def my_split(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""Split the tensor along its last dimension and keep the
corresponding slice.
"""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
output = input_list[rank].contiguous()
return output
def my_gather(input_: torch.Tensor, groups, world_size, rank) -> torch.Tensor:
"""Gather tensors and concatinate along the last dimension."""
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
if USE_CUDA:
last_dim = input_.dim() - 1
# Using all_reduce to achieve all_gather as torch.ops.c10d_functional.all_gather_into_tensor
# is buggy in 16 bits.
size = input_.size(last_dim)
padding = [0] * (2 * input_.dim())
ordinal = rank
left, right = ordinal, world_size - 1 - ordinal
idx = input_.dim() - 1 - last_dim
padding[2 * idx] = left * size
padding[2 * idx + 1] = right * size
output = torch.ops.c10d_functional.all_reduce(F.pad(input_,
padding), "sum",
TAG, RANKSET, GROUP_SIZE)
else:
output = xm.all_gather(input_, dim=-1, groups=groups)
return output
def _initialize_affine_weight(
weight: torch.Tensor,
out_features: int,
in_features: int,
per_partition_size: int,
partition_dim: int,
init_method: Callable[[torch.Tensor], torch.Tensor],
world_size: int,
rank: int,
stride: int = 1,
return_master_weight: bool = False,
) -> Optional[torch.Tensor]:
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk.
"""
# If we only use 1 process for model parallelism, bypass scatter.
if world_size == 1:
init_method(weight)
if return_master_weight:
return weight
return None
# Initialize master weight
master_weight = torch.empty(out_features,
in_features,
dtype=weight.dtype,
requires_grad=False)
init_method(master_weight)
# Split and copy
per_partition_per_stride_size = divide_and_check_no_remainder(
per_partition_size, stride)
weight_list = torch.split(master_weight,
per_partition_per_stride_size,
dim=partition_dim)
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class ParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the embedding dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
init_method: Callable[[torch.Tensor],
torch.Tensor] = init.xavier_normal_,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
) -> None:
super(ParallelEmbedding, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self._weight = None
self.quant = quant
# Divide the weight matrix along the embedding dimension.
self.embedding_dim_per_partition = divide_and_check_no_remainder(
self.embedding_dim, self.world_size)
# Allocate weights.
if quant:
self.weight = Parameter(
torch.empty(
(self.num_embeddings, self.embedding_dim_per_partition),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.num_embeddings))
else:
self.weight = Parameter(
torch.Tensor(self.num_embeddings,
self.embedding_dim_per_partition))
# And initialize.
_initialize_affine_weight(
self.weight,
self.num_embeddings,
self.embedding_dim,
self.embedding_dim_per_partition,
1,
init_method,
self.world_size,
self.rank,
stride=1,
return_master_weight=False,
)
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
input_parallel = copy_to_model_parallel_region(input_, self.groups,
self.world_size,
self.rank)
# PyTorch eager and inductor do not accept negative values in the input to embedding
# layers. Take the modulus to avoid this error.
if USE_CUDA:
input_parallel = torch.remainder(input_parallel,
self.weight.shape[0])
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output_parallel = F.embedding(
input_parallel,
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
output = gather_from_model_parallel_region(output_parallel,
self.groups,
self.world_size, self.rank)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available to
all GPUs, otherwise, every GPU will have its output which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set to
zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set
to False. It returns the master weights used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
gather_output: bool = True,
init_method: Callable[[torch.Tensor],
torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
) -> None:
super(ColumnParallelLinear, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.quant = quant
# Divide the weight matrix along the last dimension.
self.output_size_per_partition = divide_and_check_no_remainder(
out_features, self.world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
if quant:
self.weight = Parameter(
torch.empty(
(self.output_size_per_partition, self.in_features),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(
torch.Tensor(self.output_size_per_partition))
else:
self.weight = Parameter(
torch.Tensor(self.output_size_per_partition, self.in_features))
if bias:
self.bias = Parameter(torch.Tensor(self.output_size_per_partition))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.output_size_per_partition,
0,
init_method,
self.world_size,
self.rank,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(
self.weight.data.transpose(0, 1),
self.groups,
self.world_size,
self.rank,
).transpose_(0, 1)
def set_quantize(self):
assert not self.quant
self.weight = Parameter(
torch.empty((self.output_size_per_partition, self.in_features),
dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(
torch.Tensor(self.output_size_per_partition))
self.quant = True
def quantize(self):
assert not self.quant
fp_w = deepcopy(self.weight.data)
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
torch.empty((self.output_size_per_partition, self.in_features),
dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(
torch.Tensor(self.output_size_per_partition))
qconfig = TensorQConfig(axis=0)
self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
self.weight_scaler.data = scale.to(orig_dtype)
self.quant = True
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_, self.groups,
self.world_size,
self.rank)
# Matrix multiply.
if self.quant and USE_CUDA:
# GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
scaled_weight = self.weight * self.weight_scaler
output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
elif self.quant:
output_parallel = F.linear(input_parallel, self.weight, self.bias)
output_parallel = output_parallel * self.weight_scaler
else:
output_parallel = F.linear(input_parallel, self.weight, self.bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_model_parallel_region(output_parallel,
self.groups,
self.world_size,
self.rank)
else:
output = output_parallel
return output
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as
A = [A_1; ...; A_p] (stacked row-wise) and X = [X_1, ..., X_p].
Arguments:
in_features: first dimension of matrix A.
out_features: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already split
across the GPUs and we do not split again.
init_method: method to initialize weights. Note that bias is always set to
zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be set
to False. It returns the master weights used for initialization.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
input_is_parallel: bool = False,
init_method: Callable[[torch.Tensor],
torch.Tensor] = init.xavier_normal_,
stride: int = 1,
keep_master_weight_for_test: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
groups: Optional[List] = None,
quant: bool = False,
):
super(RowParallelLinear, self).__init__()
if world_size is None:
self.groups = get_model_parallel_group()
self.world_size = get_model_parallel_world_size()
self.rank = get_model_parallel_rank()
else:
self.groups = groups
self.world_size = world_size
self.rank = rank
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
self.input_is_parallel = input_is_parallel
self.quant = quant
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide_and_check_no_remainder(
in_features, self.world_size)
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
if quant:
self.weight = Parameter(
torch.empty(
(self.out_features, self.input_size_per_partition),
dtype=torch.int8,
),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
else:
self.weight = Parameter(
torch.Tensor(self.out_features, self.input_size_per_partition))
if bias:
self.bias = Parameter(torch.Tensor(self.out_features))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
# Initialize weight.
self.master_weight = _initialize_affine_weight(
self.weight,
self.out_features,
self.in_features,
self.input_size_per_partition,
1,
init_method,
self.world_size,
self.rank,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
def get_master_weight(self) -> torch.Tensor:
return gather_from_model_parallel_region(self.weight.data, self.groups,
self.world_size, self.rank)
def set_quantize(self):
assert not self.quant
self.weight = Parameter(
torch.empty((self.out_features, self.input_size_per_partition),
dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
self.quant = True
def quantize(self):
assert not self.quant
fp_w = deepcopy(self.weight.data)
orig_dtype = fp_w.dtype
fp_w = fp_w.to(torch.float32)
self.weight = Parameter(
torch.empty((self.out_features, self.input_size_per_partition),
dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = Parameter(torch.Tensor(self.out_features))
qconfig = TensorQConfig(axis=0)
self.weight.data, scale, zero_point = quantize_tensor(fp_w, qconfig)
self.weight_scaler.data = scale.to(orig_dtype)
self.quant = True
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type:ignore
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(
input_, self.groups, self.world_size, self.rank)
# Matrix multiply.
if self.quant and USE_CUDA:
# GPUs do not support mixed int8 bf16 computation. Scale int8 weights to bf16 before linear.
scaled_weight = self.weight * self.weight_scaler
output_parallel = F.linear(input_parallel, scaled_weight, self.bias)
elif self.quant:
output_parallel = F.linear(input_parallel, self.weight, self.bias)
output_parallel = output_parallel * self.weight_scaler
else:
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
output_ = reduce_from_model_parallel_region(output_parallel,
self.groups,
self.world_size, self.rank)
if self.bias is not None:
output = output_ + self.bias
else:
output = output_
return output
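# Typical composition (as used by model_xla.GemmaMLP and GemmaAttention): a
# ColumnParallelLinear with gather_output=False feeds a RowParallelLinear
# with input_is_parallel=True, so only one all-reduce is needed per
# sub-layer (attention or MLP) output.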
#!/bin/bash
python scripts/run.py --ckpt="gemma-2b-pytorch/gemma-2b-it.ckpt" --variant=2b --prompt="The meaning of life is" --device=cuda
# Model code
modelCode=560
# Model name
modelName=gemma_pytorch
# Model description
modelDescription=Gemma, the new-generation open-source 2B small model released by Google, touted as the "most powerful lightweight model in the world", firing the opening shot in the small-model war.
# Application scenarios
appScenario=inference,manufacturing,media,finance,energy,healthcare,home,education
# Framework type
frameType=pytorch
fairscale == 0.4.13
numpy == 1.24.4
immutabledict == 4.1.0
sentencepiece == 0.1.99
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import random
import numpy as np
import torch
from gemma import config
from gemma import model as gemma_model
import time
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def main(args):
# Construct the model config.
model_config = config.get_model_config(args.variant)
model_config.dtype = "float32" if args.device == "cpu" else "float16"
model_config.quant = args.quant
# Seed random.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Create the model and load the weights.
device = torch.device(args.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(args.ckpt)
model = model.to(device).eval()
print("Model loading done")
# Generate the response.
# start_time = time.time()
result = model.generate(args.prompt, device, output_len=args.output_len)
# print("infer time:", time.time() - start_time, "s")
# Print the prompts and results.
print('======================================')
print(f'PROMPT: {args.prompt}')
print(f'RESULT: {result}')
print('======================================')
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--device",
type=str,
default="cpu",
choices=["cpu", "cuda"])
parser.add_argument("--output_len", type=int, default=100)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)
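As a hedged sketch (not part of the repo), the CPU inference path that main() above wires up from the CLI can also be driven programmatically; the checkpoint path below is a placeholder.
# Hedged sketch: programmatic equivalent of `python scripts/run.py --device=cpu`.
# The checkpoint path is a placeholder; all other names come from the script above.
import torch
from gemma import config
from gemma import model as gemma_model

model_config = config.get_model_config("2b")
model_config.dtype = "float32"                      # CPU path, as in main()
device = torch.device("cpu")
torch.set_default_dtype(model_config.get_dtype())   # what _set_default_tensor_type does on entry
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights("gemma-2b-pytorch/gemma-2b-it.ckpt")  # placeholder checkpoint path
model = model.to(device).eval()
torch.set_default_dtype(torch.float)                # restore default, as on context-manager exit
print(model.generate("The meaning of life is", device, output_len=32))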
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import os
import random
import socket
import sys
from typing import List
import numpy as np
import torch
import torch.multiprocessing
from gemma.config import GemmaConfig, get_model_config
from gemma.model_xla import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import gemma.xla_model_parallel as xla_model_parallel
USE_CUDA = os.environ.get('USE_CUDA', False)
if not USE_CUDA:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
else:
# Choose an available port.
with contextlib.closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
MASTER_PORT = str(s.getsockname()[1])
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
"""Sets the default torch dtype to the given dtype."""
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(torch.float)
def generate(i: int, model_config: GemmaConfig, ckpt_path: str,
prompts: List[str], output_lens: List[int],
temperatures: List[float], top_ps: List[float],
top_ks: List[int], seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
local_rank = int(os.environ.get("LOCAL_RANK", 0))
device = torch.device("cuda", local_rank)
torch.cuda.set_device(local_rank)
else:
device = xm.xla_device()
xm.set_rng_state(seed, device)
rank = xla_model_parallel.get_model_parallel_rank()
world_size = xla_model_parallel.get_model_parallel_world_size()
if rank > 0:
sys.stdout = open(os.devnull, 'w')
# build, load and compile model.
with _set_default_tensor_type(model_config.get_dtype()):
model = GemmaForCausalLM(model_config, world_size, rank, device)
model.load_weights(ckpt_path)
model = model.to(device).eval()
# create tokenizer.
tokenizer = Tokenizer(model_config.tokenizer)
prompt_tokens = [tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
batch_size = len(prompts)
assert batch_size == len(temperatures)
assert batch_size == len(top_ps)
assert batch_size == len(top_ks)
max_seq_len = max([len(p) + o for p, o in zip(prompt_tokens, output_lens)])
assert max_seq_len <= model_config.max_position_embeddings
if model_config.num_key_value_heads < world_size:
assert world_size % model_config.num_key_value_heads == 0
n_local_heads = 1
else:
assert model_config.num_key_value_heads % world_size == 0
n_local_heads = model_config.num_key_value_heads // world_size
# build KV caches
kv_caches = []
for _ in range(model_config.num_hidden_layers):
k_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
v_cache = torch.zeros(
size=(batch_size, max_seq_len, n_local_heads,
model_config.head_dim),
dtype=model_config.get_dtype(),
device=device,
)
kv_caches.append((k_cache, v_cache))
# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
tokenizer.pad_id,
dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != tokenizer.pad_id
input_token_ids_tensor = input_token_ids_tensor.to(device)
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(device)
temperatures_tensor = torch.FloatTensor(temperatures).to(device)
top_ps_tensor = torch.FloatTensor(top_ps).to(device)
top_ks_tensor = torch.LongTensor(top_ks).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(device)
if not USE_CUDA:
xm.mark_step()
# Prefill up to min_prompt_len tokens, then treat other prefill as decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids = model(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)
curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)
input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(device)
output_index = output_index + 1
if not USE_CUDA:
xm.mark_step()
# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i]) +
output_lens[i]]
if tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(tokenizer.decode(trimmed_output))
for prompt, result in zip(prompts, results):
print('======================================')
print(f'PROMPT: {prompt}')
print(f'RESULT: {result}')
print('======================================')
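As a hedged aside (not part of the repo), the causal mask that generate() builds and slices with index_select behaves like this on a tiny sequence length.
# Hedged illustration of the mask construction in generate() above, with max_seq_len = 5.
import torch

max_seq_len = 5
mask = torch.full((1, 1, max_seq_len, max_seq_len), -2.3819763e38).to(torch.float)
mask = torch.triu(mask, diagonal=1)              # 0 on/below the diagonal, large negative above
positions = torch.arange(0, 3, dtype=torch.int64)
curr_mask = mask.index_select(2, positions)      # rows for the positions fed in this step
print(curr_mask.shape)                           # torch.Size([1, 1, 3, 5])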
def main(args):
model_config = get_model_config(args.variant)
model_config.quant = args.quant
prompts = [args.prompt]
n = len(prompts)
output_lengths = [args.output_len] * n
temperatures = [0.95] * n
top_ps = [1.0] * n
top_ks = [100] * n
if USE_CUDA:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = MASTER_PORT
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
"nccl",
rank=int(os.environ.get("RANK", 0)),
world_size=int(os.environ.get("WORLD_SIZE", 1)))
xla_model_parallel.set_g_group()
torch.multiprocessing.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
else:
xmp.spawn(
generate,
args=(
model_config,
args.ckpt,
prompts,
output_lengths,
temperatures,
top_ps,
top_ks,
args.seed,
),
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--output_len", type=int, default=4)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()
main(args)
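A hedged sketch (not part of the repo) of the per-step token selection inside generate() above: while the decode loop is still replaying prompt positions, the prompt token is kept; past the prompt, the sampled token is written back. The pad id and token values below are arbitrary placeholders, and 1-D indices are used for clarity.
# Hedged illustration only: one batch element, prompt length 3, max_seq_len 5.
import torch

pad_id = -1
token_ids = torch.tensor([[5, 6, 7, pad_id, pad_id]])
prompt_mask = token_ids != pad_id
output_index = torch.tensor([3])                        # first slot past the prompt
next_token = torch.tensor([42])                         # token sampled by the model this step
curr_prompt_mask = prompt_mask.index_select(1, output_index).squeeze(dim=1)
curr_token = token_ids.index_select(1, output_index).squeeze(dim=1)
chosen = torch.where(curr_prompt_mask, curr_token, next_token).unsqueeze(dim=1)
token_ids.index_copy_(1, output_index, chosen)
print(token_ids)  # tensor([[ 5,  6,  7, 42, -1]])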
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
from typing import List
import setuptools
ROOT_DIR = os.path.dirname(__file__)
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
def read_readme() -> str:
"""Read the README file."""
return io.open(get_path("README.md"), "r", encoding="utf-8").read()
def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
return requirements
setuptools.setup(
name="gemma",
version="0.1",
author="Gemma contributors",
license="Apache 2.0",
description=("Gemma model implementation"),
long_description=read_readme(),
long_description_content_type="text/markdown",
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(exclude=("benchmarks", "docs",
"examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
)