"git@developer.sourcefind.cn:change/sglang.git" did not exist on "e81d7f11dede2b9b3f82de00a433eccc3d47c25e"
Commit 59031aa7 authored by Neel Kant

more for pretrain_realm

parent 002cb170
@@ -284,10 +284,11 @@ class REALMBertModel(MegatronModule):
 class REALMRetriever(MegatronModule):
     """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
-    def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
+    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
         super(REALMRetriever, self).__init__()
         self.ict_model = ict_model
         self.ict_dataset = ict_dataset
+        self.block_data = block_data
         self.hashed_index = hashed_index
         self.top_k = top_k
@@ -320,8 +321,8 @@ class REALMRetriever(MegatronModule):
             block_buckets[j] = block_buckets[i].copy()
         # [batch_size x max_bucket_population x embed_size]
-        block_embeds = [torch.cuda.FloatTensor(np.array([self.hashed_index.get_block_embed(arr[3])
-                                                          for arr in bucket])) for bucket in block_buckets]
+        block_embeds = [torch.cuda.FloatTensor(np.array([self.block_data.embed_data[idx]
+                                                          for idx in bucket])) for bucket in block_buckets]
         all_top5_tokens, all_top5_pad_masks = [], []
         for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
...
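Note on the hunk above: block embeddings are now looked up directly from a BlockData store instead of going through hashed_index.get_block_embed. A minimal sketch of the new lookup path, assuming embed_data maps a block index to a fixed-size numpy embedding (BlockData's internals are not shown in this commit, so the stand-in class below is illustrative only):

import numpy as np
import torch

class BlockData:
    # Illustrative stand-in for megatron.data.realm_index.BlockData.
    def __init__(self, embed_data):
        # Assumed layout: {block_idx: np.ndarray of shape [embed_size]}
        self.embed_data = embed_data

def gather_bucket_embeds(block_data, block_buckets):
    # Mirrors the changed lines: one [bucket_size x embed_size] tensor per
    # bucket, built by direct index lookup. torch.FloatTensor is used here
    # instead of torch.cuda.FloatTensor so the sketch runs without a GPU.
    return [torch.FloatTensor(np.array([block_data.embed_data[idx]
                                        for idx in bucket]))
            for bucket in block_buckets]

block_data = BlockData({0: np.zeros(128, dtype=np.float32),
                        1: np.ones(128, dtype=np.float32)})
embeds = gather_bucket_embeds(block_data, [[0, 1], [1]])
assert embeds[0].shape == (2, 128)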
@@ -19,7 +19,7 @@ import torch
 import torch.nn.functional as F
 from hashed_index import load_ict_checkpoint, get_ict_dataset
-from megatron.data.realm_dataset import HashedIndex
+from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
@@ -39,9 +39,10 @@ def model_provider():
     ict_model = load_ict_checkpoint()
     ict_dataset = get_ict_dataset()
-    hashed_index = HashedIndex.load_from_file(args.hash_data_path)
+    all_block_data = BlockData.load_from_file(args.block_data_path)
+    hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
-    retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
     # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)
...
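Taken together, model_provider now loads two separate artifacts: the raw block embeddings (BlockData) and the LSH structure over them (RandProjectionLSHIndex). A rough sketch of the resulting setup flow; only the class, function, and argument names (args.block_data_path, args.block_index_path) come from the diff, the comments describe an assumed division of labor:

from megatron import get_args
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
from hashed_index import load_ict_checkpoint, get_ict_dataset

def build_retriever():
    args = get_args()
    ict_model = load_ict_checkpoint()    # pretrained ICT query/block encoders
    ict_dataset = get_ict_dataset()      # tokenized evidence blocks
    # Two artifacts on disk: the block embeddings and the LSH index over them.
    all_block_data = BlockData.load_from_file(args.block_data_path)
    hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
    # The retriever needs both: the index narrows a query to candidate buckets,
    # and block_data.embed_data supplies the embeddings for exact top-k scoring.
    return REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)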