"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a51b6cc86a5bef62283562d49497a4f3e0b134d8"
Unverified Commit 8ed7d97b authored by Mitchell Goff's avatar Mitchell Goff Committed by GitHub
Browse files

Update create_dynamic_map to always return a float32 tensor (#1521)

parent 86b6c37a
......@@ -389,14 +389,14 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
if signed
else 2 ** (i + non_sign_bits - max_exponent_bits + 1) + 1,
)
boundaries = torch.linspace(0.1, 1, fraction_items)
boundaries = torch.linspace(0.1, 1, fraction_items, dtype=torch.float32)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
data += (-(10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if additional_items > 0:
boundaries = torch.linspace(0.1, 1, additional_items + 1)
boundaries = torch.linspace(0.1, 1, additional_items + 1, dtype=torch.float32)
means = (boundaries[:-1] + boundaries[1:]) / 2.0
data += ((10 ** (-(max_exponent_bits - 1) + i)) * means).tolist()
if signed:
......@@ -412,7 +412,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0)
data.sort()
return torch.tensor(data)
return torch.tensor(data, dtype=torch.float32)
def create_quantile_map(A, total_bits=8):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment