"vscode:/vscode.git/clone" did not exist on "551ebc91c72184975bdec751820487ebe7326370"
Commit a0368ddf authored by Raul Puri

eval+numeric update

parent abe36e2e
@@ -34,10 +34,7 @@ from model import DistributedDataParallel as DDP
 import mpu
 from apex.optimizers import FusedAdam as Adam
 from utils import Timers
-from utils import save_checkpoint
-from utils import save_checkpoint_model_parallel
 from utils import load_checkpoint
-from utils import load_checkpoint_model_parallel
 from utils import report_memory
 from utils import print_params_min_max_norm
 from utils import print_rank_0
@@ -84,7 +81,7 @@ def setup_model(args):
     model = get_model(args)
     if args.load is not None:
-        _ = load_checkpoint_model_parallel(
+        _ = load_checkpoint(
             model, None, None, args)
     return model
...
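The model-parallel-specific checkpoint helpers are folded into the plain save_checkpoint/load_checkpoint entry points, and setup_model now calls load_checkpoint directly, passing None for the optimizer and LR scheduler since only the weights are needed for evaluation. A minimal sketch of that calling pattern, assuming a load_checkpoint that restores only the state objects it is given (the checkpoint file layout below is illustrative, not the repository's actual implementation):

    import os
    import torch

    def load_checkpoint(model, optimizer, lr_scheduler, args):
        # Hypothetical unified loader: restore only the objects that were passed in.
        state = torch.load(os.path.join(args.load, 'model.pt'), map_location='cpu')
        model.load_state_dict(state['model'])
        if optimizer is not None and 'optimizer' in state:
            optimizer.load_state_dict(state['optimizer'])
        if lr_scheduler is not None and 'lr_scheduler' in state:
            lr_scheduler.load_state_dict(state['lr_scheduler'])
        return state.get('iteration', 0)

    # Evaluation-only setup: no optimizer or scheduler to restore, iteration unused.
    # _ = load_checkpoint(model, None, None, args)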
@@ -60,6 +60,17 @@ def make_gpt2_dataloaders(args):
     valid = make_data_loader_(args.val_data_path)
     test = make_data_loader_(args.test_data_path)
 
+    args.do_train = False
+    args.do_valid = False
+    args.do_test = False
+    if train is not None:
+        args.do_train = True
+    if valid is not None:
+        args.do_valid = True
+    if test is not None:
+        args.do_test = True
+
     # Tokenizer.
     tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=args.cache_dir)
     eod_token = tokenizer.encoder['<|endoftext|>']
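The new do_train/do_valid/do_test flags simply record which splits produced a dataloader, so later evaluation code can skip splits whose data path was not given. The same presence-check pattern as a self-contained sketch (derive_split_flags is an illustrative name, not a function in the repository):

    from types import SimpleNamespace

    def derive_split_flags(args, train, valid, test):
        # A split is usable only if its dataloader was actually built.
        args.do_train = train is not None
        args.do_valid = valid is not None
        args.do_test = test is not None

    # Usage: no test data path given, so only train/valid work is enabled.
    args = SimpleNamespace()
    derive_split_flags(args, train=object(), valid=object(), test=None)
    assert args.do_train and args.do_valid and not args.do_test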
@@ -126,7 +137,8 @@ class GPT2Dataset(Dataset):
     def build_dataset_(self, shard_index):
         # Garbage collect so we don't use a lot of memory.
         # Leave the last one in case other threads have not caught up yet.
-        for i in range(shard_index - 1):
+        #for i in range(shard_index - 1):
+        for i in range(shard_index):
             self.shards_data[i] = None
             self.shards_sample_index[i] = None
         # Read the shard.
...
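Switching the garbage-collection loop from range(shard_index - 1) to range(shard_index) also frees the immediately preceding shard, so only the shard about to be read (and later ones) stays resident; the old loop kept one extra shard around for threads that had not caught up. A small sketch of the difference, using a toy list of shards:

    def free_old_shards(shards_data, shard_index, keep_previous):
        # keep_previous=True  mimics the old range(shard_index - 1) loop,
        # keep_previous=False mimics the new range(shard_index) loop.
        limit = shard_index - 1 if keep_previous else shard_index
        for i in range(limit):
            shards_data[i] = None
        return shards_data

    # About to read shard 3 of 5.
    old = free_old_shards([b'a', b'b', b'c', b'd', b'e'], 3, keep_previous=True)
    new = free_old_shards([b'a', b'b', b'c', b'd', b'e'], 3, keep_previous=False)
    assert old == [None, None, b'c', b'd', b'e']   # shard 2 kept for slower threads
    assert new == [None, None, None, b'd', b'e']   # only shard 3 onward kept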
@@ -480,10 +480,9 @@ class BertParallelSelfAttention(torch.nn.Module):
         value_layer = self._transpose_for_scores(mixed_value_layer)
 
         # Raw attention scores. [b, np, s, s]
-        attention_scores = torch.matmul(query_layer,
-                                        key_layer.transpose(-1, -2))
-        attention_scores = attention_scores / math.sqrt(
-            self.hidden_size_per_attention_head)
+        norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
+        attention_scores = torch.matmul(query_layer/norm_factor,
+                                        key_layer.transpose(-1, -2)/norm_factor)
 
         # Apply the attention mask.
         attention_scores += attention_mask
...
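The numeric part of the update: instead of dividing the raw scores by sqrt(d) after the matmul, both query and key are divided by d**0.25 beforehand. The attention scores are mathematically the same, but the matmul inputs (and its fp16 accumulation) stay in a smaller dynamic range. A quick equivalence check of the two formulations:

    import math
    import torch

    d = 64                          # hidden size per attention head
    q = torch.randn(2, 4, 16, d)    # [b, np, s, hn]
    k = torch.randn(2, 4, 16, d)

    # Old formulation: scale the scores after the matmul.
    old = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)

    # New formulation: scale query and key by d**0.25 each before the matmul.
    norm_factor = math.sqrt(math.sqrt(d))
    new = torch.matmul(q / norm_factor, k.transpose(-1, -2) / norm_factor)

    # Identical up to floating-point rounding.
    assert torch.allclose(old, new, atol=1e-5)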
@@ -221,7 +221,6 @@ def forward_step(data_iterator, model, args, timers):
     losses = mpu.vocab_parallel_cross_entropy(
         output.contiguous().float(), lm_labels.contiguous())
     loss_mask = loss_mask.contiguous()
-    loss_mask = loss_mask.view(-1)
     lm_loss = torch.sum(
         losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum()
...
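The dropped loss_mask = loss_mask.view(-1) assignment was redundant: the loss expression already flattens the mask inline, and loss_mask.sum() does not depend on shape. The masked-mean loss that remains, shown on toy tensors:

    import torch

    losses = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])   # last token is padding

    # Flatten inline; average only over unmasked tokens.
    lm_loss = torch.sum(losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum()
    assert lm_loss.item() == (1.0 + 2.0 + 3.0) / 3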