use candle::{DType, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::models::bert::{BertModel, Config}; use candle_wasm_example_bert::console_log; use tokenizers::{PaddingParams, Tokenizer}; use wasm_bindgen::prelude::*; #[wasm_bindgen] pub struct Model { bert: BertModel, tokenizer: Tokenizer, } #[wasm_bindgen] impl Model { #[wasm_bindgen(constructor)] pub fn load(weights: Vec, tokenizer: Vec, config: Vec) -> Result { console_error_panic_hook::set_once(); console_log!("loading model"); let device = &Device::Cpu; let vb = VarBuilder::from_buffered_safetensors(weights, DType::F64, device)?; let config: Config = serde_json::from_slice(&config)?; let tokenizer = Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?; let bert = BertModel::load(vb, &config)?; Ok(Self { bert, tokenizer }) } pub fn get_embeddings(&mut self, input: JsValue) -> Result { let input: Params = serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?; let sentences = input.sentences; let normalize_embeddings = input.normalize_embeddings; let device = &Device::Cpu; if let Some(pp) = self.tokenizer.get_padding_mut() { pp.strategy = tokenizers::PaddingStrategy::BatchLongest } else { let pp = PaddingParams { strategy: tokenizers::PaddingStrategy::BatchLongest, ..Default::default() }; self.tokenizer.with_padding(Some(pp)); } let tokens = self .tokenizer .encode_batch(sentences.to_vec(), true) .map_err(|m| JsError::new(&m.to_string()))?; let token_ids: Vec = tokens .iter() .map(|tokens| { let tokens = tokens.get_ids().to_vec(); Tensor::new(tokens.as_slice(), device) }) .collect::, _>>()?; let token_ids = Tensor::stack(&token_ids, 0)?; let token_type_ids = token_ids.zeros_like()?; console_log!("running inference on batch {:?}", token_ids.shape()); let embeddings = self.bert.forward(&token_ids, &token_type_ids)?; console_log!("generated embeddings {:?}", embeddings.shape()); // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?; let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?; let embeddings = if normalize_embeddings { embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)? } else { embeddings }; let embeddings_data = embeddings.to_vec2()?; Ok(serde_wasm_bindgen::to_value(&Embeddings { data: embeddings_data, })?) } } #[derive(serde::Serialize, serde::Deserialize)] struct Embeddings { data: Vec>, } #[derive(serde::Serialize, serde::Deserialize)] pub struct Params { sentences: Vec, normalize_embeddings: bool, } fn main() { console_error_panic_hook::set_once(); }