"vscode:/vscode.git/clone" did not exist on "f327eb4a27fc9370030d62a9aeea827021efd2f0"
Unverified commit e71471be, authored by OlivierDehaene and committed by GitHub

feat: add snapshot testing (#282)

parent f58f0a03
{
"generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 10264,
"text": "Test",
"logprob": null
},
{
"id": 8821,
"text": " request",
"logprob": -11.894989
}
],
"tokens": [
{
"id": 17,
"text": ".",
"logprob": -1.8267672,
"special": false
},
{
"id": 1587,
"text": "get",
"logprob": -2.4674969,
"special": false
},
{
"id": 11,
"text": "(",
"logprob": -1.906001,
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -1.2279545,
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -4.170299,
"special": false
},
{
"id": 5,
"text": "\"",
"logprob": -0.32478866,
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -1.0773665,
"special": false
},
{
"id": 30,
"text": ";",
"logprob": -0.27640742,
"special": false
},
{
"id": 837,
"text": "\n ",
"logprob": -1.6970354,
"special": false
},
{
"id": 1320,
"text": " if",
"logprob": -1.4495516,
"special": false
},
{
"id": 375,
"text": " (",
"logprob": -0.23609057,
"special": false
},
{
"id": 4899,
"text": "action",
"logprob": -1.1916996,
"special": false
},
{
"id": 3535,
"text": " ==",
"logprob": -0.8918753,
"special": false
},
{
"id": 5109,
"text": " null",
"logprob": -0.3933342,
"special": false
},
{
"id": 12,
"text": ")",
"logprob": -0.43212673,
"special": false
},
{
"id": 731,
"text": " {",
"logprob": -0.17702064,
"special": false
},
{
"id": 1260,
"text": "\n ",
"logprob": -0.07027565,
"special": false
},
{
"id": 10519,
"text": " throw",
"logprob": -1.3915029,
"special": false
},
{
"id": 2084,
"text": " new",
"logprob": -0.04201372,
"special": false
},
{
"id": 150858,
"text": " RuntimeException",
"logprob": -1.7329919,
"special": false
}
]
}
}
\ No newline at end of file
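The file above is one of the stored snapshots that the integration tests below read back and compare against a live server (judging by the prefill tokens for "Test request", presumably the bloom_560m one). For orientation, a minimal sketch of fetching the same payload by hand with the reqwest blocking client the harness uses; it assumes a launcher is already serving on port 3000 and is not part of this commit:

// Hypothetical helper, not part of this commit: POST the test prompt to a
// launcher that is already running locally and return the raw response body,
// which is the kind of payload stored under tests/ as a snapshot.
fn fetch_snapshot() -> Result<String, reqwest::Error> {
    let body = serde_json::json!({
        "inputs": "Test request",
        "parameters": { "details": true }
    });
    reqwest::blocking::Client::new()
        .post("http://localhost:3000/generate")
        .json(&body)
        .send()?
        .text()
}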
use float_eq::assert_float_eq;
use serde::Deserialize;
use serde_json::Value;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};
/// A single prefill or generated token, as returned in the `details` of a /generate response.
#[derive(Deserialize)]
pub struct Token {
id: u32,
text: String,
logprob: Option<f32>,
special: bool,
}
#[derive(Deserialize)]
struct Details {
finish_reason: String,
generated_tokens: u32,
tokens: Vec<Token>,
}
#[derive(Deserialize)]
struct GeneratedText {
generated_text: String,
details: Details,
}
/// Spawn `text-generation-launcher` for the given model and block until its
/// /health route answers; panics if the server never comes up.
fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
let argv = vec![
"text-generation-launcher".to_string(),
"--model-id".to_string(),
model_id.clone(),
"--num-shard".to_string(),
num_shard.to_string(),
"--port".to_string(),
port.to_string(),
"--master-port".to_string(),
master_port.to_string(),
"--shard-uds-path".to_string(),
format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
];
let mut launcher = Popen::create(
&argv,
PopenConfig {
stdout: Redirection::Pipe,
stderr: Redirection::Merge,
..Default::default()
},
)
.expect("Could not start launcher");
// Redirect STDOUT and STDERR to the console
// (STDERR is merged into STDOUT)
let launcher_stdout = launcher.stdout.take().unwrap();
thread::spawn(move || {
let stdout = BufReader::new(launcher_stdout);
for line in stdout.lines() {
println!("{}", line.unwrap());
}
});
// Poll the webserver's /health route until it answers, giving up after
// roughly two minutes (60 attempts, 2 s apart).
for _ in 0..60 {
let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
if health.is_ok() {
return launcher;
}
sleep(Duration::from_secs(2));
}
launcher.terminate().unwrap();
launcher.wait().unwrap();
panic!("failed to launch {}", model_id)
}
/// Start a launcher for `model_id`, send one /generate request for "Test request"
/// with `details` enabled, shut the launcher down, and return the parsed response.
fn test_model(
model_id: String,
num_shard: usize,
port: usize,
master_port: usize,
) -> GeneratedText {
let mut launcher = start_launcher(model_id, num_shard, port, master_port);
let data = r#"
{
"inputs": "Test request",
"parameters": {
"details": true
}
}"#;
let req: Value = serde_json::from_str(data).unwrap();
let client = reqwest::blocking::Client::new();
let res = client
.post(format!("http://localhost:{}/generate", port))
.json(&req)
.send();
launcher.terminate().unwrap();
launcher.wait().unwrap();
let result: GeneratedText = res.unwrap().json().unwrap();
result
}
/// Read an expected snapshot from the `tests/` directory of this crate.
fn read_json(name: &str) -> GeneratedText {
let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
d.push("tests/");
d.push(name);
let file = File::open(d).unwrap();
let reader = BufReader::new(file);
let result: GeneratedText = serde_json::from_reader(reader).unwrap();
result
}
/// Compare a live response against a stored snapshot field by field, allowing
/// an absolute tolerance of 0.001 on token logprobs.
fn compare_results(result: GeneratedText, expected: GeneratedText) {
assert_eq!(result.generated_text, expected.generated_text);
assert_eq!(result.details.finish_reason, expected.details.finish_reason);
assert_eq!(
result.details.generated_tokens,
expected.details.generated_tokens
);
for (token, expected_token) in result
.details
.tokens
.into_iter()
.zip(expected.details.tokens.into_iter())
{
assert_eq!(token.id, expected_token.id);
assert_eq!(token.text, expected_token.text);
assert_eq!(token.special, expected_token.special);
if let Some(logprob) = token.logprob {
let expected_logprob = expected_token.logprob.unwrap();
assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
} else {
assert_eq!(token.logprob, expected_token.logprob);
}
}
}
#[test]
fn test_bloom_560m() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
compare_results(result, expected);
}
#[test]
fn test_bloom_560m_distributed() {
let expected = read_json("bloom_560m.json");
let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
compare_results(result, expected);
}
#[test]
fn test_mt0_base() {
let expected = read_json("mt0_base.json");
let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
compare_results(result, expected);
}
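Adding coverage for another model would follow the same pattern: store an expected response under tests/ and register a test that points the helpers at it with a free port pair. A hedged sketch; the model id, snapshot name, and ports below are placeholders, not part of this commit:

// Hypothetical additional snapshot test; "org/some-model", "some_model.json"
// and the port numbers are placeholders chosen to avoid clashing with the
// tests above.
#[test]
fn test_some_model() {
    let expected = read_json("some_model.json");
    let result = test_model("org/some-model".to_string(), 1, 3003, 29503);
    compare_results(result, expected);
}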
{
"generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
"details": {
"finish_reason": "length",
"generated_tokens": 20,
"seed": null,
"prefill": [
{
"id": 0,
"text": "<pad>",
"logprob": null
}
],
"tokens": [
{
"id": 259,
"text": "",
"logprob": -1.3656927,
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -2.6551573,
"special": false
},
{
"id": 46138,
"text": "Test",
"logprob": -1.8059857,
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -1.2102449,
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -1.6057279,
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -3.6060903,
"special": false
},
{
"id": 304,
"text": " of",
"logprob": -0.5270343,
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -0.62522805,
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -1.4069618,
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -2.621994,
"special": false
},
{
"id": 304,
"text": " of",
"logprob": -1.3172221,
"special": false
},
{
"id": 287,
"text": " the",
"logprob": -0.3501925,
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -0.7219573,
"special": false
},
{
"id": 49076,
"text": "contents",
"logprob": -1.0494149,
"special": false
},
{
"id": 260,
"text": ".",
"logprob": -1.0803378,
"special": false
},
{
"id": 259,
"text": " ",
"logprob": -0.32933083,
"special": false
},
{
"id": 215100,
"text": "\"\"\"",
"logprob": -0.11268901,
"special": false
},
{
"id": 2978,
"text": " test",
"logprob": -1.5846587,
"special": false
},
{
"id": 290,
"text": "_",
"logprob": -0.49796978,
"special": false
},
{
"id": 4125,
"text": "test",
"logprob": -2.0026445,
"special": false
}
]
}
}
\ No newline at end of file
......@@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
- file, framework="pt", device=str(device) if not quantize else "cpu"
+ file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
full_name = f"transformer.{name}"
......
......@@ -21,16 +21,14 @@
import torch
import torch.distributed
from torch.nn import functional as F
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
# Flash attention imports
import flash_attn_cuda
import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
......@@ -332,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
- def post_load_weights(self, load_in_8bit: bool = False):
+ def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx()
for layer in self.layers:
layer: FlashLlamaLayer
- layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
- layer.self_attn.o_proj.prepare_weights(load_in_8bit)
- layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
- layer.mlp.down_proj.prepare_weights(load_in_8bit)
+ layer.self_attn.query_key_value.prepare_weights(quantize)
+ layer.self_attn.o_proj.prepare_weights(quantize)
+ layer.mlp.gate_up_proj.prepare_weights(quantize)
+ layer.mlp.down_proj.prepare_weights(quantize)
def forward(
self,
......@@ -429,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
- def post_load_weights(self, load_in_8bit: bool = False):
- self.model.post_load_weights(load_in_8bit)
+ def post_load_weights(self, quantize: Optional[str] = None):
+ self.model.post_load_weights(quantize)
self.lm_head.prepare_weights()
def forward(
......
......@@ -21,8 +21,6 @@
import torch
import torch.distributed
from torch.nn import functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
......@@ -32,7 +30,6 @@ from typing import Optional
# Flash attention imports
import flash_attn_cuda
from flash_attn.layers.rotary import RotaryEmbedding
from text_generation_server.utils.layers import (
FastLinear,
TensorParallelRowLinear,
......@@ -345,16 +342,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads
- def post_load_weights(self, load_in_8bit=False):
+ def post_load_weights(self, quantize: Optional[str] = None):
if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx()
for layer in self.layers:
layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims()
- layer.attention.query_key_value.prepare_weights(load_in_8bit)
- layer.attention.dense.prepare_weights(load_in_8bit)
- layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
- layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+ layer.attention.query_key_value.prepare_weights(quantize)
+ layer.attention.dense.prepare_weights(quantize)
+ layer.mlp.dense_h_to_4h.prepare_weights(quantize)
+ layer.mlp.dense_4h_to_h.prepare_weights(quantize)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
......@@ -457,8 +454,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
config.hidden_size, config.vocab_size, bias=False
)
- def post_load_weights(self, load_in_8bit=False):
- self.gpt_neox.post_load_weights(load_in_8bit)
+ def post_load_weights(self, quantize: Optional[str] = None):
+ self.gpt_neox.post_load_weights(quantize)
self.embed_out.prepare_weights()
@classmethod
......
import torch
import torch.distributed
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from typing import Optional
......@@ -261,16 +259,16 @@ class FlashSantacoderModel(nn.Module):
self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads
- def post_load_weights(self, load_in_8bit: bool = False):
+ def post_load_weights(self, quantize: Optional[str] = None):
if self.tp_embeddings:
self.wte.add_null_idx()
self.wpe.add_null_idx()
for layer in self.h:
layer: Block
- layer.attn.c_attn.prepare_weights(load_in_8bit)
- layer.attn.c_proj.prepare_weights(load_in_8bit)
- layer.mlp.c_fc.prepare_weights(load_in_8bit)
- layer.mlp.c_proj.prepare_weights(load_in_8bit)
+ layer.attn.c_attn.prepare_weights(quantize)
+ layer.attn.c_proj.prepare_weights(quantize)
+ layer.mlp.c_fc.prepare_weights(quantize)
+ layer.mlp.c_proj.prepare_weights(quantize)
def forward(
self,
......@@ -347,8 +345,8 @@ class FlashSantacoderForCausalLM(nn.Module):
else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
- def post_load_weights(self, load_in_8bit: bool = False):
- self.transformer.post_load_weights(load_in_8bit)
+ def post_load_weights(self, quantize: Optional[str] = None):
+ self.transformer.post_load_weights(quantize)
self.lm_head.prepare_weights()
def forward(
......
......@@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__)
class FlashLlama(FlashCausalLM):
- def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ ):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.float16
......@@ -72,14 +77,14 @@ class FlashLlama(FlashCausalLM):
def load_weights(
model,
filenames: List[Path],
- quantize: bool,
+ quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
):
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
- value = value.to(device if not quantize else "cpu").to(dtype)
+ value = value.to(device if quantize is None else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4])
......@@ -199,7 +204,7 @@ class FlashLlamaSharded(FlashLlama):
def load_weights(
model,
filenames: List[str],
- quantize: bool,
+ quantize: Optional[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
......@@ -207,7 +212,7 @@ class FlashLlamaSharded(FlashLlama):
):
for file in filenames:
with safe_open(
- file, framework="pt", device=str(device) if not quantize else "cpu"
+ file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
slice_ = f.get_slice(name)
......
......@@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__)
class FlashNeoX(FlashCausalLM):
- def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+ def __init__(
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
+ ):
super(FlashNeoX, self).__init__(
FlashGPTNeoXForCausalLM, model_id, revision, quantize
)
......@@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM):
class FlashNeoXSharded(FlashNeoX):
def __init__(
- self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+ self,
+ model_id: str,
+ revision: Optional[str] = None,
+ quantize: Optional[str] = None,
):
self.process_group, rank, world_size = initialize_torch_distributed()
if torch.cuda.is_available():
......@@ -89,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
- file, framework="pt", device=str(device) if not quantize else "cpu"
+ file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
......
......@@ -255,7 +255,7 @@ class GalacticaSharded(Galactica):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
- file, framework="pt", device=str(device) if not quantize else "cpu"
+ file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
if name == "lm_head.weight":
......
......@@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
- file, framework="pt", device=str(device) if not quantize else "cpu"
+ file, framework="pt", device=str(device) if quantize is None else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
......
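The hunks above replace the old load_in_8bit / boolean quantize flag with an optional quantization scheme name, so every device-selection check changes from "if not quantize" to "if quantize is None". A minimal sketch of the resulting rule, in Python to match the touched files; the helper name and the "bitsandbytes" value are illustrative assumptions, not shown in the diff:

from typing import Optional

import torch


def weight_load_device(device: torch.device, quantize: Optional[str]) -> str:
    # Mirrors the pattern repeated above: when a quantization scheme is
    # requested (e.g. "bitsandbytes"), weights are first loaded on CPU;
    # otherwise they go straight to the target device.
    return str(device) if quantize is None else "cpu"


assert weight_load_device(torch.device("cuda:0"), None) == "cuda:0"
assert weight_load_device(torch.device("cuda:0"), "bitsandbytes") == "cpu"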