Unverified commit 9c256a17 authored by Ceng, committed by GitHub

issue/106: adapt to the 9G7B model

parent 39bea30a
@@ -173,7 +173,7 @@ class LlamaConfig(PretrainedConfig):
 tie_word_embeddings=False,
 rope_theta=10000.0,
 rope_scaling=None,
-attention_bias=False,
+attention_bias=True,
 attention_dropout=0.0,
 mlp_bias=False,
 head_dim=None,
...
@@ -157,7 +157,7 @@ class LlamaAttention(infinicore.nn.Module):
 self.o_proj = infinicore.nn.Linear(
 self.num_attention_heads * self.head_dim,
 self.hidden_size,
-bias=attention_bias,
+bias=False,
 **kwargs,
 )
...
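Taken together, the two hunks above change the `LlamaConfig` default to `attention_bias=True` while hard-coding `bias=False` on the `o_proj` layer, presumably because the 9G7B weights ship biases on the q/k/v projections but none on the output projection. A minimal sketch of the resulting wiring, assuming the HF-style config/attention structure these hunks touch (every name other than `attention_bias`, `o_proj`, and `infinicore.nn.Linear` is illustrative):

import infinicore


class SketchAttention(infinicore.nn.Module):
    """Illustrative only: shows how the patched defaults are expected to be consumed."""

    def __init__(self, config, **kwargs):
        super().__init__()
        head_dim = config.hidden_size // config.num_attention_heads
        # q/k/v projections follow config.attention_bias, which now defaults to True
        self.q_proj = infinicore.nn.Linear(
            config.hidden_size,
            config.num_attention_heads * head_dim,
            bias=config.attention_bias,
            **kwargs,
        )
        # ... k_proj and v_proj are built the same way ...
        # the output projection ignores the flag and is always bias-free after this change
        self.o_proj = infinicore.nn.Linear(
            config.num_attention_heads * head_dim,
            config.hidden_size,
            bias=False,
            **kwargs,
        )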
import sys
import os
import re
from datasets import load_dataset
import infinicore
import infinilm
from infinilm.models.llama import AutoLlamaModel
from infinilm.modeling_utils import get_model_state_dict
from infinilm.distributed import DistConfig
from abc import ABC, abstractmethod
class BaseBenchmark(ABC):
"""Base class for benchmark evaluation with common tokenizer and generation utilities"""
def encode_text(self, text):
"""Encode text to token IDs - reused across backends"""
return self.tokenizer.encode(text)
def decode_token(self, token_id):
"""Decode token ID to text - reused across backends"""
return self.tokenizer.decode(token_id)
@abstractmethod
def render_input_content(self, *args, **kwargs):
"""Render input content - benchmark-specific implementation"""
pass
@abstractmethod
def generate(self, *args, **kwargs):
"""Generate response - benchmark-specific implementation"""
pass
@abstractmethod
def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_):
"""Backend-specific generation implementation"""
pass
class InfiniLMBenchmark(BaseBenchmark):
"""Wrapper class for InfiniLM cpp backend for benchmark evaluation"""
def __init__(self, model_dir_path, device_type_str="cpu", ndev=1, backend="cpp", benchmark="ceval"):
import transformers
self.benchmark = benchmark
# Map device type string to infinicore device
device_map = {
"cpu": "cpu",
"nvidia": "cuda",
"cambricon": "cambricon",
"ascend": "ascend",
"metax": "metax",
"moore": "moore",
"iluvatar": "iluvatar",
"kunlun": "kunlun",
"hygon": "hygon",
}
device_name = device_map.get(device_type_str.lower(), "cpu")
# CUDA_VISIBLE_DEVICES is automatically respected by CUDA runtime API
# When CUDA_VISIBLE_DEVICES=5 is set, CUDA only sees device 5 as device 0
# So device index 0 will automatically map to the first visible device
self.device = infinicore.device(device_name, 0)
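# Illustrative usage (hypothetical path): to pin the run to the sixth GPU, launch as
#   CUDA_VISIBLE_DEVICES=5 python test_benchmark.py --nvidia /path/to/model_dir --bench ceval
# and infinicore still addresses that card as device ("cuda", 0).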
self.dtype = infinicore.bfloat16
# Load config and tokenizer
with open(os.path.join(model_dir_path, "config.json"), "r") as f:
import json
self.config_dict = json.load(f)
# Align tokenizer initialization with jiuge backend (010)
# Match the exact same initialization logic based on model type
model_type = self.config_dict.get("model_type", "")
if model_type == "llama":
# For llama models: no trust_remote_code (matches jiuge line 465)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
elif model_type in ["fm9g", "minicpm", "fm9g7b"]:
# For fm9g/minicpm/fm9g7b models: use trust_remote_code=True (matches jiuge lines 493-495, 518-520)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
elif model_type in ["qwen2", "qwen3"]:
# For qwen2/qwen3 models: no trust_remote_code (matches jiuge line 534-536)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path)
else:
# Default: use trust_remote_code=True for other models
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
eos_token_id = self.config_dict.get("eos_token_id")
self.eos_token_id = (
[eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
)
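# Note: eos_token_id in config.json may be a single int or a list of ints; the
# normalization above keeps it as a list either way (e.g. 2 -> [2]).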
# Create model with cpp backend
print("Loading model with cpp backend...")
self.model = AutoLlamaModel.from_pretrained(
model_dir_path,
device=self.device,
dtype=self.dtype,
backend=backend,
distributed_config=DistConfig(ndev),
)
# Enable KV cache for generation
self.model.use_cache = True
# Load weights
print("Loading model weights...")
model_param_infini = get_model_state_dict(
model_dir_path,
device=self.device,
dtype=self.dtype,
)
self.model.load_state_dict(model_param_infini)
print("Model loaded successfully")
def max_context_len(self):
return self.config_dict.get("max_position_embeddings", 2048)
def render_input_content(self, *args, **kwargs):
"""Render input content based on benchmark type"""
if self.benchmark == "ceval":
return self._render_ceval(*args, **kwargs)
elif self.benchmark == "mmlu":
return self._render_mmlu(*args, **kwargs)
else:
raise ValueError(f"Unknown benchmark: {self.benchmark}")
def _render_ceval(self, conversation):
"""Render C-Eval conversation to input content"""
return (
self.tokenizer.apply_chat_template(
conversation=conversation,
add_generation_prompt=True,
tokenize=False,
)
+ "正确答案是"
)
def _render_mmlu(self, question, choices):
"""Render MMLU question and choices to input content"""
choices_text = "\n".join([f"{chr(65+i)}. {choice}" for i, choice in enumerate(choices)])
instruction = (
"You are a multiple-choice question solver. "
"Select the correct option and respond with only the letter A, B, C, or D."
)
prompt = f"{instruction}\n\nQuestion: {question}\n{choices_text}\nAnswer:"
# Use chat template if available, otherwise return plain text
if hasattr(self.tokenizer, 'apply_chat_template'):
conversation = [
{"role": "system", "content": instruction},
{"role": "user", "content": f"{question}\n{choices_text}\nAnswer:"}
]
try:
return self.tokenizer.apply_chat_template(
conversation=conversation,
add_generation_prompt=True,
tokenize=False,
)
except Exception:
return prompt
return prompt
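# Illustrative fallback prompt produced above when no chat template is available
# (the question and choices here are hypothetical):
#   You are a multiple-choice question solver. Select the correct option and respond with only the letter A, B, C, or D.
#
#   Question: What is 2 + 2?
#   A. 3
#   B. 4
#   C. 5
#   D. 6
#   Answer: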
def generate(self, *args, max_steps=500, topp_=1.0, topk_=1, temperature_=1.0):
"""Generate response based on benchmark type"""
# Render input content
input_content = self.render_input_content(*args)
print(input_content, end="", flush=True)
# Encode input
tokens = self.encode_text(input_content)
# Delegate to backend-specific generation implementation
output_content, avg_time = self._generate_step(
tokens, max_steps, topp_, topk_, temperature_
)
return output_content, avg_time
def _generate_step(self, tokens, max_steps, topp_, topk_, temperature_):
"""
InfiniLM cpp backend-specific generation implementation
NOTE: Validation confirmed input configs are identical between backends.
The issue was that the manual generation loop called InferEngine.generate(), which
doesn't maintain KV cache. Solution: Use model's built-in generate() method
which properly handles KV cache through GenerationMixin.
"""
# Convert tokens to infinicore format
input_ids_list = [tokens]
input_ids = infinicore.from_list(input_ids_list, dtype=infinicore.int64).to(self.device)
# Use model's built-in generate() method which properly handles KV cache
# Pass sampling parameters (temperature, topk, topp) via kwargs
output_tokens_list, output_content = self.model.generate(
input_ids=input_ids,
max_new_tokens=max_steps,
tokenizer=self.tokenizer,
stop_on_eos=True,
temperature=temperature_,
topk=topk_,
topp=topp_,
)
# GenerationMixin.generate() does not return per-step timing, so report a placeholder
print("\n")
avg_time = 0.0
print("Time per step: N/A (using GenerationMixin.generate)")
return output_content, avg_time
def destroy_model_instance(self):
# Cleanup if needed
del self.model
print("Model destroyed")
def extract_answer_ceval(output_content, answer):
"""Extract predicted answer from C-Eval output"""
output_upper = output_content.upper().strip()
# The prompt ends with "正确答案是", so the answer letter is expected within the first two characters
return answer in output_upper[:2]
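# Examples (hypothetical model outputs):
#   extract_answer_ceval("A。解析:...", "A") -> True   (letter found in the first two characters)
#   extract_answer_ceval("选B", "A")         -> False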
def extract_answer_mmlu(output_content):
"""Extract predicted answer from MMLU output (returns 0-3 index or None)"""
output_upper = output_content.upper().strip()
# Find first meaningful token
match = re.search(r"\b([ABCD])\b", output_upper)
if match:
return ord(match.group(1)) - ord('A')
else:
match_num = re.search(r"\b([0-3])\b", output_upper)
if match_num:
return int(match_num.group(1))
return None
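# Examples (hypothetical model outputs):
#   extract_answer_mmlu("The answer is B.") -> 1
#   extract_answer_mmlu("2")                -> 2
#   extract_answer_mmlu("no idea")          -> None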
def test():
# Parse arguments manually to handle device flags properly
if len(sys.argv) < 4:
print(
"Usage: python test_benchmark.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--dataset DATASET] [--num_samples N] [--max_new_tokens N]"
)
sys.exit(1)
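# Example invocations (paths and values are illustrative):
#   python test_benchmark.py --nvidia /path/to/model_dir --bench ceval --dataset middle_school_mathematics --num_samples 20
#   python test_benchmark.py --cpu /path/to/model_dir --bench mmlu --subject anatomy --max_new_tokens 16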
# Parse device flag (first argument)
device_flag = sys.argv[1]
model_path = sys.argv[2]
# Parse optional arguments
backend = "cpp"
ndev = 1
benchmark = None
subject = None # For MMLU
dataset_name = "middle_school_mathematics" # For C-Eval
num_samples = None
max_new_tokens = 500
i = 3
while i < len(sys.argv):
if sys.argv[i] == "--bench" and i + 1 < len(sys.argv):
benchmark = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--backend" and i + 1 < len(sys.argv):
backend = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--ndev" and i + 1 < len(sys.argv):
ndev = int(sys.argv[i + 1])
i += 2
elif sys.argv[i] == "--subject" and i + 1 < len(sys.argv):
subject = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--dataset" and i + 1 < len(sys.argv):
dataset_name = sys.argv[i + 1]
i += 2
elif sys.argv[i] == "--num_samples" and i + 1 < len(sys.argv):
num_samples = int(sys.argv[i + 1])
i += 2
elif sys.argv[i] == "--max_new_tokens" and i + 1 < len(sys.argv):
max_new_tokens = int(sys.argv[i + 1])
i += 2
else:
i += 1
if benchmark is None:
print("Error: --bench argument is required. Choose 'ceval' or 'mmlu'")
sys.exit(1)
if benchmark not in ["ceval", "mmlu"]:
print(f"Error: Unknown benchmark '{benchmark}'. Choose 'ceval' or 'mmlu'")
sys.exit(1)
# Parse device type
device_type_str = "cpu"
if device_flag == "--cpu":
device_type_str = "cpu"
elif device_flag == "--nvidia":
device_type_str = "nvidia"
elif device_flag == "--cambricon":
device_type_str = "cambricon"
elif device_flag == "--ascend":
device_type_str = "ascend"
elif device_flag == "--metax":
device_type_str = "metax"
elif device_flag == "--moore":
device_type_str = "moore"
elif device_flag == "--iluvatar":
device_type_str = "iluvatar"
elif device_flag == "--kunlun":
device_type_str = "kunlun"
elif device_flag == "--hygon":
device_type_str = "hygon"
else:
print(
"Usage: python test_benchmark.py [--cpu | --nvidia | --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] <path/to/model_dir> --bench [ceval|mmlu] [--backend cpp] [--ndev N] [--subject SUBJECT] [--dataset DATASET] [--num_samples N] [--max_new_tokens N]"
)
sys.exit(1)
# Load dataset based on benchmark
if benchmark == "ceval":
# Load C-Eval dataset
# https://huggingface.co/datasets/ceval/ceval-exam/tree/main/middle_school_geography
print(f"Loading C-Eval dataset (dataset: {dataset_name})...")
try:
dataset = load_dataset(r"ceval/ceval-exam", name=dataset_name)
samples = dataset["val"]
# Convert Dataset to list if needed
if hasattr(samples, 'to_list'):
samples = samples.to_list()
else:
samples = list(samples)
except Exception as e:
print(f"Error loading dataset: {e}")
print("Available datasets: middle_school_mathematics, high_school_history, high_school_chinese, high_school_physics, middle_school_geography, middle_school_physics")
sys.exit(1)
elif benchmark == "mmlu":
# Load MMLU dataset
# https://huggingface.co/datasets/cais/mmlu
if subject is None:
subject = "all"
print(f"Loading MMLU dataset (subject: {subject})...")
try:
if subject == "all":
# The "all" config already merges every subject into a single test split,
# so read that split directly instead of iterating over split names.
dataset = load_dataset("cais/mmlu", "all")
test_data = dataset["test"]
# Convert Dataset to list
if hasattr(test_data, 'to_list'):
samples = test_data.to_list()
else:
samples = list(test_data)
else:
dataset = load_dataset("cais/mmlu", subject)
test_data = dataset["test"]
# Convert Dataset to list
if hasattr(test_data, 'to_list'):
samples = test_data.to_list()
else:
samples = list(test_data)
except Exception as e:
print(f"Error loading dataset: {e}")
print("Available subjects: abstract_algebra, anatomy, astronomy, business_ethics, etc.")
print("Use --subject all to load all subjects")
sys.exit(1)
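# From here on, `samples` is a list of dicts. A C-Eval sample is expected to provide
# 'question', 'A', 'B', 'C', 'D', 'answer' (and optionally 'id'); an MMLU sample is
# expected to provide 'question', 'choices' (four strings) and 'answer' (a 0-3 index),
# matching the field accesses in the evaluation loop below.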
print(f"Loaded {len(samples)} samples")
# Limit number of samples if specified
if num_samples is not None and num_samples > 0:
original_count = len(samples)
samples = samples[:num_samples]
print(f"Limited to {len(samples)} samples for validation (from {original_count} total)")
# Create model based on backend
if backend != "010":
model = InfiniLMBenchmark(model_path, device_type_str, ndev, backend, benchmark)
else:
print(f"test 010 backend by scripts/test_ceval.py")
exit(0)
# Test with first sample if available
if len(samples) > 0:
sample = samples[0]
if benchmark == "ceval":
input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
test_conversation = [
{
"role": "system",
"content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
},
{"role": "user", "content": input_content},
]
test_output, _ = model.generate(test_conversation, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0)
elif benchmark == "mmlu":
question = sample['question']
choices = sample['choices']
test_output, _ = model.generate(question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0)
print(f"\nTest output: {test_output}")
answers_list = []
for idx, sample in enumerate(samples):
if benchmark == "ceval":
input_content = f"'question':{sample['question']},'A': {sample['A']}, 'B':{sample['B']}, 'C': {sample['C']},'D': {sample['D']}。"
conversation = [
{
"role": "system",
"content": "请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。",
},
{"role": "user", "content": input_content},
]
answer = sample["answer"]
output_content, avg_time = model.generate(
conversation, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
)
is_correct = extract_answer_ceval(output_content, answer)
answers_list.append({
"id": sample.get("id", idx),
"output_content": output_content,
"answer": answer,
"is_correct": is_correct
})
if benchmark == "ceval":
print("标准答案:", answer)
elif benchmark == "mmlu":
question = sample['question']
choices = sample['choices']
answer_idx = sample['answer'] # MMLU answer is 0-3 index
output_content, avg_time = model.generate(
question, choices, max_steps=max_new_tokens, topp_=1.0, topk_=1, temperature_=1.0
)
predicted_answer = extract_answer_mmlu(output_content)
# Convert answer index to letter for display
answer_letter = chr(65 + answer_idx) if answer_idx < 4 else "?"
predicted_letter = chr(65 + predicted_answer) if predicted_answer is not None and predicted_answer < 4 else "?"
print(f"Sample {idx}: Correct answer: {answer_letter} ({answer_idx}), Predicted: {predicted_letter} ({predicted_answer})")
answers_list.append({
"id": idx,
"output_content": output_content,
"answer": answer_idx,
"predicted": predicted_answer
})
model.destroy_model_instance()
print("-------------------------------------------------------------")
# Evaluate results
true_num = 0
all_num = 0
for cont in answers_list:
id = cont["id"]
all_num = all_num + 1
if benchmark == "ceval":
answer = cont["answer"]
is_correct = cont["is_correct"]
if is_correct:
true_num = true_num + 1
print(f"id {id} : ", "正确")
else:
print(f"id {id}: ", "错误")
elif benchmark == "mmlu":
answer = cont["answer"]
predicted = cont["predicted"]
if predicted is not None and predicted == answer:
true_num = true_num + 1
print(f"id {id}: Correct")
else:
answer_letter = chr(65 + answer) if answer < 4 else "?"
predicted_letter = chr(65 + predicted) if predicted is not None and predicted < 4 else "?"
print(f"id {id}: Wrong (correct: {answer_letter}, predicted: {predicted_letter})")
accuracy = true_num / all_num if all_num > 0 else 0.0
if benchmark == "ceval":
print(f"成绩: {true_num}/{all_num}", accuracy)
else:
print(f"Accuracy: {true_num}/{all_num} = {accuracy:.2%}")
if __name__ == "__main__":
test()
@@ -4,7 +4,6 @@ Test script to validate forward pass across different backends and dtypes.
 Tests:
 1. Python backend with bfloat16
-2. C++ backend with float32
 3. C++ backend with bfloat16
 This script runs a prefill step (full sequence forward pass with KV cache)
@@ -81,6 +80,12 @@ def get_args():
 default="How are you",
 help="Test prompt (default: 'How are you')",
 )
+parser.add_argument(
+"--num_decode_steps",
+type=int,
+default=2,
+help="Number of decode steps to run after prefill (default: 2)",
+)
 return parser.parse_args()
@@ -116,9 +121,9 @@ def create_inputs(prompt, tokenizer, device, backend="cpp"):
 return input_ids_infini, position_ids_infini, input_content
-def run_forward_pass(model, input_ids, position_ids, backend, dtype):
-"""Run prefill and first decode step with KV cache, return decode step logits."""
-print(f" Running forward pass (prefill + first decode step)...")
+def run_forward_pass(model, input_ids, position_ids, backend, dtype, num_decode_steps=2):
+"""Run prefill and multiple decode steps with KV cache, return all decode step logits."""
+print(f" Running forward pass (prefill + {num_decode_steps} decode step(s))...")
 try:
 # Get the underlying model
@@ -162,19 +167,6 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype):
 print(
 f" Prefill logits stats: min={prefill_logits_np.min():.6f}, max={prefill_logits_np.max():.6f}, mean={prefill_logits_np.mean():.6f}")
-# Step 2: Decode - run forward pass with single token
-# Get the predicted token from prefill
-if np.isnan(prefill_logits_np).any():
-# If prefill has NaN, use a default token to continue testing decode step
-print(
-f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits")
-predicted_token_id = 29902
-else:
-predicted_token_id = int(
-prefill_logits_np.argmax(axis=-1)[0, 0])
-print(
-f" Step 2: Decode (next_token_id={predicted_token_id})...")
 # Get device from input_ids
 if hasattr(input_ids, "device"):
 input_device = input_ids.device
@@ -182,19 +174,59 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype):
 input_device = getattr(
 position_ids, "device", infinicore.device("cpu", 0))
+# Initialize decode logits list
+decode_logits_list = []
+seq_len = input_ids.shape[1]
+current_token_id = None
+# Run multiple decode steps
+for decode_step in range(num_decode_steps):
+# Get the predicted token from previous step
+if decode_step == 0:
+# First decode step: use token from prefill
+if np.isnan(prefill_logits_np).any():
+print(f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits")
+current_token_id = 29902
+else:
+current_token_id = int(prefill_logits_np.argmax(axis=-1)[0, 0])
+else:
+# Subsequent decode steps: use token from previous decode
+prev_logits_np = decode_logits_list[-1]
+if np.isnan(prev_logits_np).any():
+print(f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits")
+current_token_id = 29902
+else:
+current_token_id = int(prev_logits_np.argmax(axis=-1)[0, 0])
+print(f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})...")
 # Create single token input for decode step
 decode_input_ids = infinicore.from_list(
-[[predicted_token_id]], device=input_device)
+[[current_token_id]], device=input_device)
-# Create position_ids for decode step (should be seq_len, since we've processed seq_len tokens)
+# Create position_ids for decode step
-seq_len = input_ids.shape[1]
 decode_position_ids = infinicore.from_list(
-[[seq_len]], dtype=infinicore.int64, device=input_device
+[[seq_len + decode_step]], dtype=infinicore.int64, device=input_device
 )
 # Run decode step - C++ backend manages cache internally
 decode_logits = underlying_model.forward(
 decode_input_ids, decode_position_ids)
+# Convert decode logits to numpy
+decode_logits_np = infinicore_to_numpy(decode_logits)
+decode_logits_list.append(decode_logits_np)
+print(f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}")
+# Check decode logits for issues
+if np.isnan(decode_logits_np).any():
+print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!")
+print(f" NaN count: {np.isnan(decode_logits_np).sum()}")
+if np.isinf(decode_logits_np).any():
+print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!")
+print(f" Inf count: {np.isinf(decode_logits_np).sum()}")
+if not np.isnan(decode_logits_np).any():
+print(f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}")
 else:
 # Python backend uses DynamicCache
 # Get model config
@@ -217,12 +249,6 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype):
 print(
 f" ✓ Prefill completed, logits shape: {prefill_logits_np.shape}")
-# Step 2: Decode - run forward pass with single token
-# Get the predicted token from prefill
-predicted_token_id = int(prefill_logits_np.argmax(axis=-1)[0, 0])
-print(
-f" Step 2: Decode (next_token_id={predicted_token_id})...")
 # Get device from input_ids
 if hasattr(input_ids, "device"):
 input_device = input_ids.device
@@ -231,14 +257,39 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype):
 input_device = getattr(
 position_ids, "device", infinicore.device("cpu", 0))
+# Initialize decode logits list
+decode_logits_list = []
+seq_len = input_ids.shape[1]
+current_token_id = None
+# Run multiple decode steps
+for decode_step in range(num_decode_steps):
+# Get the predicted token from previous step
+if decode_step == 0:
+# First decode step: use token from prefill
+if np.isnan(prefill_logits_np).any():
+print(f" ⚠ WARNING: Using default token 29902 due to NaN in prefill logits")
+current_token_id = 29902
+else:
+current_token_id = int(prefill_logits_np.argmax(axis=-1)[0, 0])
+else:
+# Subsequent decode steps: use token from previous decode
+prev_logits_np = decode_logits_list[-1]
+if np.isnan(prev_logits_np).any():
+print(f" ⚠ WARNING: Using default token 29902 due to NaN in decode step {decode_step} logits")
+current_token_id = 29902
+else:
+current_token_id = int(prev_logits_np.argmax(axis=-1)[0, 0])
+print(f" Step {decode_step + 2}: Decode step {decode_step + 1} (next_token_id={current_token_id})...")
 # Create single token input for decode step
 decode_input_ids = infinicore.from_list(
-[[predicted_token_id]], device=input_device)
+[[current_token_id]], device=input_device)
-# Create position_ids for decode step (should be seq_len, since we've processed seq_len tokens)
+# Create position_ids for decode step
-seq_len = input_ids.shape[1]
 decode_position_ids = infinicore.from_list(
-[[seq_len]], dtype=infinicore.int64, device=input_device
+[[seq_len + decode_step]], dtype=infinicore.int64, device=input_device
 )
 # Run decode step with KV cache
@@ -246,33 +297,47 @@ def run_forward_pass(model, input_ids, position_ids, backend, dtype):
 decode_input_ids, decode_position_ids, past_key_values=past_key_values, use_cache=True
 )
-# Convert decode logits to numpy for analysis
-logits_np = infinicore_to_numpy(decode_logits)
-print(f" ✓ Forward pass completed (prefill + decode)")
-print(f" Decode logits shape: {logits_np.shape}")
-print(f" Decode logits dtype: {logits_np.dtype}")
-print(
-f" Decode logits stats: min={logits_np.min():.6f}, max={logits_np.max():.6f}, mean={logits_np.mean():.6f}")
-# Check for issues
+# Convert decode logits to numpy
+decode_logits_np = infinicore_to_numpy(decode_logits)
+decode_logits_list.append(decode_logits_np)
+print(f" ✓ Decode step {decode_step + 1} completed, logits shape: {decode_logits_np.shape}")
+# Check decode logits for issues
+if np.isnan(decode_logits_np).any():
+print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain NaN values!")
+print(f" NaN count: {np.isnan(decode_logits_np).sum()}")
+if np.isinf(decode_logits_np).any():
+print(f" ⚠ WARNING: Decode step {decode_step + 1} logits contain Inf values!")
+print(f" Inf count: {np.isinf(decode_logits_np).sum()}")
+if not np.isnan(decode_logits_np).any():
+print(f" Decode step {decode_step + 1} logits stats: min={decode_logits_np.min():.6f}, max={decode_logits_np.max():.6f}, mean={decode_logits_np.mean():.6f}")
+# Summary of all decode steps
+print(f" ✓ Forward pass completed (prefill + {num_decode_steps} decode step(s))")
+for i, logits_np in enumerate(decode_logits_list):
+print(f" Decode step {i + 1} logits shape: {logits_np.shape}, dtype: {logits_np.dtype}")
+# Check for issues in all decode steps
+has_error = False
+for i, logits_np in enumerate(decode_logits_list):
 if np.isnan(logits_np).any():
-print(f" ⚠ WARNING: Logits contain NaN values!")
-return None, True
+print(f" ⚠ WARNING: Decode step {i + 1} logits contain NaN values!")
+print(f" NaN count: {np.isnan(logits_np).sum()}")
+has_error = True
 if np.isinf(logits_np).any():
-print(f" ⚠ WARNING: Logits contain Inf values!")
-return None, True
+print(f" ⚠ WARNING: Decode step {i + 1} logits contain Inf values!")
+print(f" Inf count: {np.isinf(logits_np).sum()}")
+has_error = True
-# Check if logits are too small (might indicate model not working)
 if np.abs(logits_np).max() < 1.0:
-print(
-f" ⚠ WARNING: Logits are very small (max abs: {np.abs(logits_np).max():.6f})")
+print(f" ⚠ WARNING: Decode step {i + 1} logits are very small (max abs: {np.abs(logits_np).max():.6f})")
-# Get predicted token from decode step
-predicted_token = int(logits_np.argmax(axis=-1)[0, 0])
-print(f" Predicted token ID from decode: {predicted_token}")
-return logits_np, False
+# Get predicted token from last decode step
+if decode_logits_list and not np.isnan(decode_logits_list[-1]).any():
+predicted_token = int(decode_logits_list[-1].argmax(axis=-1)[0, 0])
+print(f" Predicted token ID from decode step {num_decode_steps}: {predicted_token}")
+# Return tuple of all decode logits
+return tuple(decode_logits_list), has_error
 except Exception as e:
 print(f" ✗ Forward pass failed: {e}")
@@ -353,7 +418,7 @@ def infinicore_to_numpy(tensor):
 return result
-def test_configuration(model_path, device, backend, dtype, prompt):
+def test_configuration(model_path, device, backend, dtype, prompt, num_decode_steps=2):
 """Test a specific backend/dtype configuration."""
 print("\n" + "=" * 80)
 print(f"Testing: Backend={backend}, Dtype={dtype}")
@@ -377,7 +442,7 @@ def test_configuration(model_path, device, backend, dtype, prompt):
 # Load tokenizer
 print("\n1. Loading tokenizer...")
 try:
-tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 print(f" ✓ Tokenizer loaded")
 except Exception as e:
 print(f" ✗ Failed to load tokenizer: {e}")
@@ -428,25 +493,25 @@ def test_configuration(model_path, device, backend, dtype, prompt):
 traceback.print_exc()
 return None, True
-# Run forward pass (prefill + decode step)
+# Run forward pass (prefill + multiple decode steps)
-print(f"\n5. Running forward pass (prefill + first decode step)...")
+print(f"\n5. Running forward pass (prefill + {num_decode_steps} decode step(s))...")
-logits, has_error = run_forward_pass(
+logits_tuple, has_error = run_forward_pass(
-model, input_ids, position_ids, backend, dtype)
+model, input_ids, position_ids, backend, dtype, num_decode_steps)
 if has_error:
 return None, True
-return logits, False
+return logits_tuple, False
-def compare_logits(logits1, logits2, name1, name2):
+def compare_logits(logits1, logits2, name1, name2, step_name="logits"):
 """Compare two logits arrays."""
 print(f"\n{'=' * 80}")
-print(f"Comparing: {name1} vs {name2}")
+print(f"Comparing: {name1} vs {name2} ({step_name})")
 print(f"{'=' * 80}")
 if logits1 is None or logits2 is None:
-print(" ✗ Cannot compare: one or both logits are None")
+print(f" ✗ Cannot compare: one or both {step_name} are None")
 return False
 if logits1.shape != logits2.shape:
@@ -469,9 +534,9 @@ def compare_logits(logits1, logits2, name1, name2):
 is_close = np.allclose(logits1, logits2, rtol=rtol, atol=atol)
 if is_close:
-print(f" ✓ Logits are close (within tolerance)")
+print(f" ✓ {step_name.capitalize()} are close (within tolerance)")
 else:
-print(f" ⚠ Logits differ significantly")
+print(f" ⚠ {step_name.capitalize()} differ significantly")
 # Show top differences
 flat_diff = diff.flatten()
 top_indices = np.argsort(flat_diff)[-10:][::-1]
@@ -493,6 +558,7 @@ def main():
 print(f"Model path: {args.model_path}")
 print(f"Device: {args.device}")
 print(f"Prompt: {args.prompt}")
+print(f"Number of decode steps: {args.num_decode_steps}")
 print("=" * 80)
 results = {}
@@ -502,25 +568,16 @@ def main():
 print("TEST 1: Python Backend + BFloat16")
 print("=" * 80)
 logits_py_bf16, error = test_configuration(
-args.model_path, args.device, "python", "bfloat16", args.prompt
+args.model_path, args.device, "python", "bfloat16", args.prompt, args.num_decode_steps
 )
 results["python_bf16"] = (logits_py_bf16, error)
-# Test 2: C++ backend with float32
-print("\n\n" + "=" * 80)
-print("TEST 2: C++ Backend + Float32")
-print("=" * 80)
-logits_cpp_f32, error = test_configuration(
-args.model_path, args.device, "cpp", "float32", args.prompt
-)
-results["cpp_f32"] = (logits_cpp_f32, error)
 # Test 3: C++ backend with bfloat16
 print("\n\n" + "=" * 80)
 print("TEST 3: C++ Backend + BFloat16")
 print("=" * 80)
 logits_cpp_bf16, error = test_configuration(
-args.model_path, args.device, "cpp", "bfloat16", args.prompt
+args.model_path, args.device, "cpp", "bfloat16", args.prompt, args.num_decode_steps
 )
 results["cpp_bf16"] = (logits_cpp_bf16, error)
@@ -533,23 +590,22 @@ def main():
 # Compare Python BF16 vs C++ BF16 (should be similar)
 if not results["python_bf16"][1] and not results["cpp_bf16"][1]:
+py_logits = results["python_bf16"][0]
+cpp_logits = results["cpp_bf16"][0]
+if py_logits is not None and cpp_logits is not None:
+# Compare all decode steps
+num_steps = min(len(py_logits), len(cpp_logits))
+for step_idx in range(num_steps):
+step_name = f"decode step {step_idx + 1}"
 is_close = compare_logits(
-results["python_bf16"][0],
+py_logits[step_idx],
-results["cpp_bf16"][0],
+cpp_logits[step_idx],
 "Python BF16",
-"C++ BF16"
+"C++ BF16",
+step_name
 )
-comparisons.append(("Python BF16 vs C++ BF16", is_close))
-# Compare C++ F32 vs C++ BF16 (should be similar but with some differences)
-if not results["cpp_f32"][1] and not results["cpp_bf16"][1]:
-is_close = compare_logits(
-results["cpp_f32"][0],
-results["cpp_bf16"][0],
-"C++ F32",
-"C++ BF16"
-)
-comparisons.append(("C++ F32 vs C++ BF16", is_close))
+comparisons.append((f"Python BF16 vs C++ BF16 ({step_name})", is_close))
 # Summary
 print("\n\n" + "=" * 80)
...