"sgl-kernel/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "3bdcdd134b1c05b2c232172fa929652e477542a8"
Unverified commit b9952499, authored by topduke, committed by GitHub
Browse files

Change int64 to int32 for training on Windows

parent 1008376c
...@@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer): ...@@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1) log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal( non_pad_mask = paddle.not_equal(
tgt, paddle.zeros( tgt, paddle.zeros(
tgt.shape, dtype='int64')) tgt.shape, dtype='int32'))
loss = -(one_hot * log_prb).sum(axis=1) loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean() loss = loss.masked_select(non_pad_mask).mean()
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment