# Bayesian linear regression on synthetic data, sampled with the NumPyro
# NUTS backend through PyMC.
import jax
import pymc as pm
import numpy as np
import pytensor as pt

# Build the PyTensor graph in single precision.
pt.config.floatX = "float32"

# One seed for both the synthetic data and the sampler so that a run is
# reproducible end to end.
RANDOM_SEED = 123

# Synthetic data: Y = 3 * X + standard-normal noise.
np.random.seed(RANDOM_SEED)
n = 10000
X = np.random.randn(n)
Y = 3 * X + np.random.randn(n)

# Cast after generation (so the drawn values are unchanged) to match floatX;
# float64 data would otherwise upcast `mu` in the float32 graph.
X = X.astype(pt.config.floatX)
Y = Y.astype(pt.config.floatX)

# Define the PyMC model.
with pm.Model() as model:
    # Weakly-informative priors on intercept, slope and noise scale.
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10)
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Linear predictor and Gaussian likelihood of the observed responses.
    mu = alpha + beta * X
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

    # Keep the default InferenceData return type: pm.summary consumes it
    # directly, so return_inferencedata=False is intentionally not passed.
    trace = pm.sample(1000, nuts_sampler="numpyro", random_seed=RANDOM_SEED)

print(pm.summary(trace))