# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma import tokenizer


class Sampler(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()
        self.vocab_size = vocab_size

    @torch.no_grad()
    def forward(
        self,
        embedding: torch.Tensor,
        hidden_states: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: Optional[torch.Tensor],
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Select the last element for each sequence.
        # (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
        hidden_states = hidden_states.index_select(
            1, output_positions).squeeze(dim=1)

        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
            logits += embedding_bias

        if temperatures is None:
            return torch.argmax(logits, dim=-1).squeeze(dim=-1)

        # Apply temperature scaling.
        logits.div_(temperatures.unsqueeze(dim=1))

        # Calculate probabilities with softmax.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

        # Apply top-p, top-k.
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
        probs_sort = torch.where(top_ps_mask, 0, probs_sort)

        top_ks_mask = torch.arange(probs_idx.shape[-1],
                                   device=probs_idx.device)
        top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
        top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
        probs_sort = torch.where(top_ks_mask, 0, probs_sort)

        # Re-normalization.
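        # The top-p / top-k masking above zeroes out part of the probability
        # mass, so the remaining probabilities no longer sum to 1; rescale
        # them before sampling with torch.multinomial below.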
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        probs = torch.gather(probs_sort,
                             dim=-1,
                             index=torch.argsort(probs_idx, dim=-1))

        next_token_ids = torch.multinomial(probs,
                                           num_samples=1,
                                           replacement=True).squeeze(dim=-1)
        return next_token_ids


def precompute_freqs_cis(dim: int,
                         end: int,
                         theta: float = 10000.0) -> torch.Tensor:
    """Precomputes the frequency cis."""
    freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Applies the rotary embedding to the query and key tensors."""
    x_ = torch.view_as_complex(
        torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
                    dim=-1))
    x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
    x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
    x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
                          -1).transpose(1, 2)
    return x_out


class Linear(nn.Module):

    def __init__(self, in_features: int, out_features: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features), dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
        else:
            self.weight = nn.Parameter(
                torch.empty((out_features, in_features)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.linear(x, weight)
        return output


class Embedding(nn.Module):

    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
        super().__init__()
        if quant:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim),
                            dtype=torch.int8),
                requires_grad=False,
            )
            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
        else:
            self.weight = nn.Parameter(
                torch.empty((num_embeddings, embedding_dim)),
                requires_grad=False,
            )
        self.quant = quant

    def forward(self, x):
        weight = self.weight
        if self.quant:
            weight = weight * self.weight_scaler.unsqueeze(-1)
        output = F.embedding(x, weight)
        return output


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        dim: int,
        eps: float = 1e-6,
        add_unit_offset: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.add_unit_offset = add_unit_offset
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offset:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output


class GemmaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        quant: bool,
    ):
        super().__init__()
        self.gate_proj = Linear(hidden_size, intermediate_size, quant)
        self.up_proj = Linear(hidden_size, intermediate_size, quant)
        self.down_proj = Linear(intermediate_size, hidden_size, quant)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs


class GemmaAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        quant: bool,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        self.hidden_size = hidden_size
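        # Grouped-query attention: when num_kv_heads < num_heads, each
        # key/value head is shared by num_queries_per_kv query heads
        # (expanded with repeat_interleave in forward below).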
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
            quant=quant)
        self.o_proj = Linear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            quant=quant)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states_shape = hidden_states.shape
        assert len(hidden_states_shape) == 3

        batch_size, input_len, _ = hidden_states_shape

        qkv = self.qkv_proj(hidden_states)
        xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
                               dim=-1)

        xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
        xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
        xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

        # Positional embedding.
        xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
        xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

        # Write new kv cache.
        # [batch_size, input_len, n_local_kv_heads, head_dim]
        k_cache, v_cache = kv_cache
        k_cache.index_copy_(1, kv_write_indices, xk)
        v_cache.index_copy_(1, kv_write_indices, xv)

        key = k_cache
        value = v_cache
        if self.num_kv_heads != self.num_heads:
            # [batch_size, max_seq_len, n_local_heads, head_dim]
            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
            value = torch.repeat_interleave(value,
                                            self.num_queries_per_kv,
                                            dim=2)

        # [batch_size, n_local_heads, input_len, head_dim]
        q = xq.transpose(1, 2)
        # [batch_size, n_local_heads, max_seq_len, head_dim]
        k = key.transpose(1, 2)
        v = value.transpose(1, 2)

        # [batch_size, n_local_heads, input_len, max_seq_len]
        scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
        scores = scores + mask
        scores = F.softmax(scores.float(), dim=-1).type_as(q)

        # [batch_size, n_local_heads, input_len, head_dim]
        output = torch.matmul(scores, v)

        # [batch_size, input_len, hidden_dim]
        output = (output.transpose(1, 2).contiguous().view(
            batch_size, input_len, -1))
        output = self.o_proj(output)
        return output


class GemmaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.self_attn = GemmaAttention(
            hidden_size=config.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            quant=config.quant,
        )
        self.mlp = GemmaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant=config.quant,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_cache=kv_cache,
            mask=mask,
        )
        hidden_states = residual + hidden_states

        # MLP
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class GemmaModel(nn.Module):

    def __init__(self, config: gemma_config.GemmaConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
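        # The transformer trunk: a stack of decoder layers followed by a
        # final RMSNorm over the hidden states.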
        self.layers = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layers.append(GemmaDecoderLayer(config))
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
    ) -> torch.Tensor:
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                freqs_cis=freqs_cis,
                kv_write_indices=kv_write_indices,
                kv_cache=kv_caches[i],
                mask=mask,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class GemmaForCausalLM(nn.Module):

    def __init__(
        self,
        config: gemma_config.GemmaConfig,
    ):
        super().__init__()
        self.config = config
        assert config.hidden_size % config.num_attention_heads == 0

        max_seq_len = config.max_position_embeddings
        head_dim = config.head_dim
        vocab_size = config.vocab_size

        self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
        self.embedder = Embedding(vocab_size, config.hidden_size,
                                  config.quant)
        self.model = GemmaModel(config)
        self.sampler = Sampler(vocab_size)

        # Pre-compute rotary embedding table.
        rope_theta = getattr(config, 'rope_theta', 10000)
        freqs_cis = precompute_freqs_cis(head_dim,
                                         max_seq_len * 2,
                                         theta=rope_theta)
        self.register_buffer('freqs_cis', freqs_cis)

    @torch.no_grad()
    def forward(
        self,
        input_token_ids: torch.Tensor,
        input_positions: torch.Tensor,
        kv_write_indices: torch.Tensor,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        mask: torch.Tensor,
        output_positions: torch.Tensor,
        temperatures: torch.Tensor,
        top_ps: torch.Tensor,
        top_ks: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        freqs_cis = self.freqs_cis.index_select(0, input_positions)
        kv_write_indices = input_positions

        # [batch_size, input_len, hidden_size]
        hidden_states = self.embedder(input_token_ids)
        # Gemma normalizes the embedding by sqrt(hidden_size).
        hidden_states = hidden_states * (self.config.hidden_size**0.5)

        hidden_states = self.model(
            hidden_states=hidden_states,
            freqs_cis=freqs_cis,
            kv_write_indices=kv_write_indices,
            kv_caches=kv_caches,
            mask=mask,
        )
        embedder_weight = self.embedder.weight
        if self.config.quant:
            embedder_weight = (
                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
        next_tokens = self.sampler(
            embedding=embedder_weight,
            hidden_states=hidden_states,
            output_positions=output_positions,
            temperatures=temperatures,
            top_ps=top_ps,
            top_ks=top_ks,
        )
        return next_tokens

    def generate(
        self,
        prompts: Union[str, Sequence[str]],
        device: Any,
        output_len: int = 100,
        temperature: float = 0.95,
        top_p: float = 1.0,
        top_k: int = 100,
    ) -> Union[str, Sequence[str]]:
        """Generates responses for given prompts using Gemma model."""
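        # Decoding runs in two phases: a single prefill pass over the first
        # min_prompt_len tokens, followed by one token per step. Positions
        # still covered by a longer prompt keep the prompt token rather than
        # the sampled one (see the torch.where in the loop below).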
        # If a single prompt is provided, treat it as a batch of 1.
        is_str_prompt = isinstance(prompts, str)
        if is_str_prompt:
            prompts = [prompts]

        batch_size = len(prompts)
        prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
        min_prompt_len = min(len(p) for p in prompt_tokens)
        max_prompt_len = max(len(p) for p in prompt_tokens)
        max_seq_len = max_prompt_len + output_len
        assert max_seq_len <= self.config.max_position_embeddings

        # Build KV caches.
        kv_caches = []
        for _ in range(self.config.num_hidden_layers):
            size = (batch_size, max_seq_len, self.config.num_key_value_heads,
                    self.config.head_dim)
            dtype = self.config.get_dtype()
            k_cache = torch.zeros(size=size, dtype=dtype, device=device)
            v_cache = torch.zeros(size=size, dtype=dtype, device=device)
            kv_caches.append((k_cache, v_cache))

        # Prepare inputs.
        token_ids_tensor = torch.full((batch_size, max_seq_len),
                                      self.tokenizer.pad_id,
                                      dtype=torch.int64)
        input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
                                            self.tokenizer.pad_id,
                                            dtype=torch.int64)
        for i, p in enumerate(prompt_tokens):
            token_ids_tensor[i, :len(p)] = torch.tensor(p)
            input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
                p[:min_prompt_len])
        token_ids_tensor = token_ids_tensor.to(device)
        input_token_ids_tensor = input_token_ids_tensor.to(device)
        prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
        input_positions_tensor = torch.arange(0, min_prompt_len,
                                              dtype=torch.int64).to(device)
        mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
                                 -2.3819763e38).to(torch.float)
        mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
        curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
        output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
            device)
        temperatures_tensor = torch.FloatTensor([temperature] *
                                                batch_size).to(device)
        top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
        top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
        output_index = torch.tensor(min_prompt_len,
                                    dtype=torch.int64).to(device)

        # Prefill up to min_prompt_len tokens, then treat other prefill as
        # decode and ignore output.
        for i in range(max_seq_len - min_prompt_len):
            next_token_ids = self(
                input_token_ids=input_token_ids_tensor,
                input_positions=input_positions_tensor,
                kv_write_indices=None,
                kv_caches=kv_caches,
                mask=curr_mask_tensor,
                output_positions=output_positions_tensor,
                temperatures=temperatures_tensor,
                top_ps=top_ps_tensor,
                top_ks=top_ks_tensor,
            )
            curr_prompt_mask = prompt_mask_tensor.index_select(
                1, output_index).squeeze(dim=1)
            curr_token_ids = token_ids_tensor.index_select(
                1, output_index).squeeze(dim=1)
            output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
                                           next_token_ids).unsqueeze(dim=1)
            token_ids_tensor.index_copy_(1, output_index, output_token_ids)

            input_token_ids_tensor = output_token_ids
            input_positions_tensor = output_index.unsqueeze(dim=-1)
            curr_mask_tensor = mask_tensor.index_select(
                2, input_positions_tensor)
            output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
                device)
            output_index = output_index + 1

        # Detokenization.
        token_ids = token_ids_tensor.tolist()
        results = []
        for i, tokens in enumerate(token_ids):
            trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
                                    + output_len]
            if self.tokenizer.eos_id in trimmed_output:
                eos_index = trimmed_output.index(self.tokenizer.eos_id)
                trimmed_output = trimmed_output[:eos_index]
            results.append(self.tokenizer.decode(trimmed_output))

        # If a string was provided as input, return a string as output.
        return results[0] if is_str_prompt else results

    def load_weights(self, model_path: str):
        self.load_state_dict(
            torch.load(
                model_path, mmap=True, weights_only=True,
            )['model_state_dict'],
            strict=False,
        )
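

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library API. It assumes that
    # `gemma_config.GemmaConfig()` can be constructed with defaults that match
    # the checkpoint being loaded, and that CHECKPOINT_PATH (a hypothetical
    # path) points at a file in the format expected by `load_weights`, i.e. a
    # dict containing a 'model_state_dict' entry. Adjust both to your setup;
    # only `GemmaForCausalLM`, `load_weights`, and `generate` from this module
    # are exercised here.
    CHECKPOINT_PATH = '/path/to/gemma_checkpoint.ckpt'  # hypothetical path

    config = gemma_config.GemmaConfig()  # assumption: defaults fit checkpoint
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = GemmaForCausalLM(config)
    model.load_weights(CHECKPOINT_PATH)
    model = model.to(device).eval()

    # Generate a short continuation for a single prompt.
    print(model.generate('What is the capital of France?', device,
                         output_len=32))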