Commit 4d94cd51 authored by jiaruifang's avatar jiaruifang Committed by Frank Lee
Browse files

adapting bert unitest interface

parent 7977422a
......@@ -48,9 +48,21 @@ def get_training_components():
num_hidden_layers=num_layer,
)
print('building BertForSequenceClassification model')
model = BertForSequenceClassification(config)
# adapting huggingface BertForSequenceClassification for single unitest calling interface
class ModelAaptor(BertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable()
return model
trainloader = get_bert_data_loader(batch_size=2,
......
......@@ -31,11 +31,11 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
loss.backward()
def run_bert_fwd_bwd(model, data, label, enable_autocast=False):
# with no criterion
def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
model.train()
with torch.cuda.amp.autocast(enabled=enable_autocast):
output = model(input_ids=data, labels=label)
loss = output[0]
loss = model(data, label)
if isinstance(model, ShardedModelV2):
model.backward(loss)
else:
......@@ -60,8 +60,8 @@ def run_dist(rank, world_size, port):
if model_name == 'bert':
data, label = data.cuda(), label.cuda()
run_bert_fwd_bwd(model, data, label, False)
run_bert_fwd_bwd(zero_model, data, label, False)
run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False)
else:
data, label = data.half().cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment